Files
xc-llm-ascend/vllm_ascend/lora/utils.py
yupeng 29f195a91c [Bugfix][LoRA] Fix the bug when runs Qwen3-Reranker-0.6B with LoRA. (#7156)
### What this PR does / why we need it?
Fix the error reported while initializing the Qwen3-Reranker-0.6B model
with `--enable-lora`.
Also add a test case to verify the fix.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
2026-03-15 17:55:42 +08:00

200 lines
6.9 KiB
Python
Executable File

import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendReplicatedLinear,
AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA wrapper selected for ``AscendColumnParallelLinear`` layers.

    Selection is additionally gated by vLLM's ``_not_fully_sharded_can_replace``
    decorator (the non-sharded counterpart of
    ``AscendColumnParallelLinearWithShardedLoRA``).
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True if ``source_layer`` is exactly an ``AscendColumnParallelLinear``.

        Args:
            source_layer: Candidate layer to be wrapped with LoRA.
            lora_config: LoRA configuration (presumably read by the decorator).
            packed_modules_list: Packed module names for the layer; unused here.
            model_config: Model config; unused here. Defaults to ``None`` so the
                signature matches the other Ascend wrappers in this module.
        """
        # Exact type match (not isinstance): subclasses of the Ascend layer,
        # if any, must not be claimed by this wrapper.
        return type(source_layer) is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """LoRA wrapper selected for ``AscendMergedColumnParallelLinear`` layers.

    Selection is additionally gated by vLLM's ``_not_fully_sharded_can_replace``
    decorator (non-sharded counterpart of
    ``AscendMergedColumnParallelLinearWithShardedLoRA``).
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True if ``source_layer`` is exactly an ``AscendMergedColumnParallelLinear``.

        Args:
            source_layer: Candidate layer to be wrapped with LoRA.
            lora_config: LoRA configuration (presumably read by the decorator).
            packed_modules_list: Packed module names for the layer; not inspected
                here, so any number of packed modules is accepted.
            model_config: Model config; unused here. Defaults to ``None`` so the
                signature matches the other Ascend wrappers in this module.
        """
        # NOTE(review): upstream vLLM additionally checks the packed-module count
        # for merged column layers; confirm accepting any count is intended here.
        return type(source_layer) is AscendMergedColumnParallelLinear
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
    """LoRA wrapper selected for ``AscendRowParallelLinear`` layers.

    Selection is additionally gated by vLLM's ``_not_fully_sharded_can_replace``
    decorator (non-sharded counterpart of ``AscendRowParallelLinearWithShardedLoRA``).
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True if ``source_layer`` is exactly an ``AscendRowParallelLinear``.

        Args:
            source_layer: Candidate layer to be wrapped with LoRA.
            lora_config: LoRA configuration (presumably read by the decorator).
            packed_modules_list: Packed module names for the layer; unused here.
            model_config: Model config; unused here. Defaults to ``None`` so the
                signature matches the other Ascend wrappers in this module.
        """
        # Exact type match so that subclasses are not picked up implicitly.
        return type(source_layer) is AscendRowParallelLinear
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    """LoRA wrapper selected for ``AscendVocabParallelEmbedding`` layers.

    Unlike the linear wrappers in this module, ``can_replace_layer`` carries no
    fully-sharded gating decorator, so it is considered regardless of that setting.
    """

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True if ``source_layer`` is exactly an ``AscendVocabParallelEmbedding``.

        Args:
            source_layer: Candidate layer to be wrapped with LoRA.
            lora_config: LoRA configuration; unused here.
            packed_modules_list: Packed module names for the layer; unused here.
            model_config: Model config; unused here. Defaults to ``None`` so the
                signature matches the other Ascend wrappers in this module.
        """
        # Exact type match so that subclasses are not picked up implicitly.
        return type(source_layer) is AscendVocabParallelEmbedding
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
    """LoRA wrapper for ``AscendQKVParallelLinear`` layers with one packed module.

    Selection is additionally gated by vLLM's ``_not_fully_sharded_can_replace``
    decorator. The three-packed-module case is handled by
    ``AscendMergedQKVParallelLinearWithLoRA`` instead.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True for an exact ``AscendQKVParallelLinear`` with one packed module.

        Args:
            source_layer: Candidate layer to be wrapped with LoRA.
            lora_config: LoRA configuration (presumably read by the decorator).
            packed_modules_list: Packed module names for the layer; must have
                exactly one entry for this wrapper to apply.
            model_config: Model config; unused here. Defaults to ``None`` so the
                signature matches the other Ascend wrappers in this module.
        """
        return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
    """LoRA wrapper for ``AscendQKVParallelLinear`` layers with three packed modules.

    Selection is additionally gated by vLLM's ``_not_fully_sharded_can_replace``
    decorator. The single-packed-module case is handled by
    ``AscendQKVParallelLinearWithLoRA`` instead.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True for an exact ``AscendQKVParallelLinear`` with three packed modules.

        Args:
            source_layer: Candidate layer to be wrapped with LoRA.
            lora_config: LoRA configuration (presumably read by the decorator).
            packed_modules_list: Packed module names for the layer; must have
                exactly three entries for this wrapper to apply.
            model_config: Model config; unused here. Defaults to ``None`` so the
                signature matches the other Ascend wrappers in this module.
        """
        return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
class AscendReplicatedLinearWithLoRA(ReplicatedLinearWithLoRA):
    """LoRA wrapper selected for ``AscendReplicatedLinear`` layers.

    A replicated linear layer is, by definition, copied on every device, so it
    is always eligible for replacement regardless of the fully sharded LoRAs
    setting. That is why ``can_replace_layer`` carries no sharding decorator.
    """

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Tell whether ``source_layer`` is exactly an ``AscendReplicatedLinear``."""
        layer_cls = type(source_layer)
        return layer_cls is AscendReplicatedLinear
class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA):
    """Fully-sharded LoRA wrapper for ``AscendColumnParallelLinear`` layers.

    Selection is gated by vLLM's ``_fully_sharded_can_replace`` decorator.
    """

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Tell whether ``source_layer`` is exactly an ``AscendColumnParallelLinear``."""
        layer_cls = type(source_layer)
        return layer_cls is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithShardedLoRA):
    """Fully-sharded LoRA wrapper for ``AscendMergedColumnParallelLinear`` layers.

    Selection is gated by vLLM's ``_fully_sharded_can_replace`` decorator.
    """

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Tell whether ``source_layer`` is exactly an ``AscendMergedColumnParallelLinear``."""
        layer_cls = type(source_layer)
        return layer_cls is AscendMergedColumnParallelLinear
class AscendMergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithShardedLoRA):
    """Fully-sharded LoRA wrapper for ``AscendQKVParallelLinear`` with three packed modules.

    Selection is gated by vLLM's ``_fully_sharded_can_replace`` decorator.
    """

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Tell whether ``source_layer`` is an exact ``AscendQKVParallelLinear`` packing three modules."""
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
class AscendQKVParallelLinearWithShardedLoRA(QKVParallelLinearWithShardedLoRA):
    """Fully-sharded LoRA wrapper for ``AscendQKVParallelLinear`` with one packed module.

    Selection is gated by vLLM's ``_fully_sharded_can_replace`` decorator.
    """

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Tell whether ``source_layer`` is an exact ``AscendQKVParallelLinear`` packing one module."""
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 1
class AscendRowParallelLinearWithShardedLoRA(RowParallelLinearWithShardedLoRA):
    """Fully-sharded LoRA wrapper for ``AscendRowParallelLinear`` layers.

    Selection is gated by vLLM's ``_fully_sharded_can_replace`` decorator.
    """

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Tell whether ``source_layer`` is exactly an ``AscendRowParallelLinear``."""
        layer_cls = type(source_layer)
        return layer_cls is AscendRowParallelLinear
def refresh_all_lora_classes():
    """Add every Ascend LoRA wrapper class to ``vllm.lora.utils._all_lora_classes``.

    The classes are added one at a time with ``.add`` in a fixed order.
    Presumably vLLM consults this collection (via each class's
    ``can_replace_layer``) when deciding how to wrap a layer for LoRA;
    confirm against ``vllm.lora.utils``.
    """
    registry = vllm.lora.utils._all_lora_classes
    ascend_wrappers = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
        AscendQKVParallelLinearWithLoRA,
        AscendMergedQKVParallelLinearWithLoRA,
        AscendColumnParallelLinearWithShardedLoRA,
        AscendMergedColumnParallelLinearWithShardedLoRA,
        AscendMergedQKVParallelLinearWithShardedLoRA,
        AscendQKVParallelLinearWithShardedLoRA,
        AscendRowParallelLinearWithShardedLoRA,
        AscendReplicatedLinearWithLoRA,
    )
    for wrapper_cls in ascend_wrappers:
        registry.add(wrapper_cls)