Files
xc-llm-ascend/vllm_ascend/lora/utils.py
yupeng 830f39dd70 [Bugfix][LoRA] Fix the issue when enable LoRA + tp + fully_sharded_loras (#6650)
### What this PR does / why we need it?
Fix the issue #6143 .

### Does this PR introduce _any_ user-facing change?
Allow to start the server with "--enable-lora && --fully-sharded-loras
&& --tensor_parallel_size 2".

### How was this patch tested?
pytest -sv tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
2026-03-11 15:43:15 +08:00

183 lines
6.2 KiB
Python
Executable File

import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendMergedColumnParallelLinear
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendRowParallelLinear
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendVocabParallelEmbedding
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendMergedColumnParallelLinear
class AscendMergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
class AscendQKVParallelLinearWithShardedLoRA(QKVParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
class AscendRowParallelLinearWithShardedLoRA(RowParallelLinearWithShardedLoRA):
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is AscendRowParallelLinear
def refresh_all_lora_classes():
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA)
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA)