[Bugfix][LoRA] Fix LoRA bug after supporting Qwen3-Next (#3044)

### What this PR does / why we need it?
The LoRA e2e test uses the ilama-3.2-1B model, which is served by the
transformers.py model files. Its self-attention layer names end with "\*.attn",
not "\*.self_attn".

There are some other models whose attention layer names end with "*.attn",
such as baichuan.py and bert.py.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py

- vLLM version: v0.10.2
- vLLM main:
17b4c6685c

---------

Signed-off-by: paulyu12 <507435917@qq.com>
This commit is contained in:
yupeng
2025-09-26 11:12:45 +08:00
committed by GitHub
parent 8406aafaff
commit 9caf6fbaf5
2 changed files with 35 additions and 2 deletions

View File

@@ -92,7 +92,7 @@ jobs:
pytest -sv tests/e2e/singlecard/test_chunked.py pytest -sv tests/e2e/singlecard/test_chunked.py
pytest -sv tests/e2e/singlecard/test_embedding.py pytest -sv tests/e2e/singlecard/test_embedding.py
pytest -sv tests/e2e/singlecard/test_guided_decoding.py pytest -sv tests/e2e/singlecard/test_guided_decoding.py
#pytest -sv tests/e2e/singlecard/test_ilama_lora.py pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py
pytest -sv tests/e2e/singlecard/test_quantization.py pytest -sv tests/e2e/singlecard/test_quantization.py
pytest -sv tests/e2e/singlecard/test_sampler.py pytest -sv tests/e2e/singlecard/test_sampler.py
@@ -174,7 +174,7 @@ jobs:
# external_launcher test is not stable enough. Fix it later # external_launcher test is not stable enough. Fix it later
# pytest -sv tests/e2e/multicard/test_external_launcher.py # pytest -sv tests/e2e/multicard/test_external_launcher.py
pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
#pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
# To avoid oom, we need to run the test in a single process. # To avoid oom, we need to run the test in a single process.
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ

View File

@@ -6,11 +6,15 @@ from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA, from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA, RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA) VocabParallelEmbeddingWithLoRA)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (AscendColumnParallelLinear, from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear, AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendRowParallelLinear) AscendRowParallelLinear)
from vllm_ascend.ops.vocab_parallel_embedding import \ from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding AscendVocabParallelEmbedding
@@ -69,9 +73,38 @@ class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
return type(source_layer) is AscendVocabParallelEmbedding return type(source_layer) is AscendVocabParallelEmbedding
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
    """LoRA adapter for Ascend's (non-merged) QKV parallel linear layer."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Only take over an Ascend QKV linear that is packed as a single
        # (unmerged) module; three packed modules means the merged variant.
        is_ascend_qkv = type(source_layer) is AscendQKVParallelLinear
        return is_ascend_qkv and len(packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
    """LoRA adapter for Ascend's merged QKV parallel linear layer."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # A merged QKV layer packs exactly three sub-modules (q, k, v);
        # anything else belongs to the non-merged variant.
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
def refresh_all_lora_classes():
    """Register every Ascend LoRA layer class with vLLM's LoRA machinery.

    Adds each Ascend-specific wrapper to ``vllm.lora.utils._all_lora_classes``
    so vLLM can substitute them for the corresponding base layers.
    """
    ascend_lora_classes = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
        AscendQKVParallelLinearWithLoRA,
        AscendMergedQKVParallelLinearWithLoRA,
    )
    for lora_cls in ascend_lora_classes:
        vllm.lora.utils._all_lora_classes.add(lora_cls)