[Bugfix][LoRA] Fix the bug when runs Qwen3-Reranker-0.6B with LoRA. (#7156)

### What this PR does / why we need it?
Fix the error that reports while initializing qwen3-reranker-0.6b model
with `--enable-lora`.
And add a testcase to verify the fix.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
yupeng
2026-03-15 17:55:42 +08:00
committed by GitHub
parent 7daccf4b64
commit 29f195a91c
6 changed files with 108 additions and 4 deletions

View File

@@ -15,12 +15,14 @@ from vllm.lora.layers import (
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
@@ -103,6 +105,20 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
class AscendReplicatedLinearWithLoRA(ReplicatedLinearWithLoRA):
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendReplicatedLinear
class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
@@ -180,3 +196,4 @@ def refresh_all_lora_classes():
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendReplicatedLinearWithLoRA)

View File

@@ -433,7 +433,7 @@ class AscendReplicatedLinear(ReplicatedLinear):
return_bias: bool = True,
disable_tp: bool = False,
):
self.custom_op = get_replicated_op(disable_tp, prefix, self)
self.custom_op, self.tp_rank, self.tp_size = get_replicated_op(disable_tp, prefix, self)
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes

View File

@@ -734,11 +734,12 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
return None, get_tp_group().rank_in_group, get_tp_group().world_size
def get_replicated_op(disable_tp, prefix, layer) -> CustomReplicatedOp | None:
def get_replicated_op(disable_tp, prefix, layer) -> tuple[CustomReplicatedOp | None, int | None, int | None]:
if disable_tp:
return None
return None, None, None
return CustomReplicatedOp(layer)
custom_op = CustomReplicatedOp(layer)
return custom_op, custom_op.tp_rank, custom_op.tp_size
def is_moe_layer(prefix: str) -> bool: