[Misc] Move lora patch file into lora module (#2797)

Clean up the unused file in the patch module. Updating the LoRA support list
in vLLM Ascend is sufficient; there is no need to patch vLLM.


- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-09-08 21:42:12 +08:00
committed by GitHub
parent 85d989a3b9
commit 7d6d9449a8
10 changed files with 64 additions and 72 deletions

View File

@@ -11,12 +11,14 @@ if is_310p():
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
else:
from vllm_ascend.lora.punica_wrapper.lora_ops import (
bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm_ascend.lora.utils import refresh_all_lora_classes
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
@@ -31,6 +33,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
refresh_all_lora_classes()
def _shrink_prefill(
self,

77
vllm_ascend/lora/utils.py Normal file
View File

@@ -0,0 +1,77 @@
from typing import Optional
import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA adapter targeting Ascend's column-parallel linear layer."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only for an exact AscendColumnParallelLinear instance.

        An exact type check (not isinstance) is used so subclasses are not
        silently matched by this wrapper.
        """
        layer_cls = type(source_layer)
        return layer_cls is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithLoRA(
        MergedColumnParallelLinearWithLoRA):
    """LoRA adapter targeting Ascend's merged column-parallel linear layer."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only for an exact AscendMergedColumnParallelLinear.

        Exact type identity (not isinstance) keeps subclasses from being
        matched by accident.
        """
        layer_cls = type(source_layer)
        return layer_cls is AscendMergedColumnParallelLinear
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
    """LoRA adapter targeting Ascend's row-parallel linear layer."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only for an exact AscendRowParallelLinear instance.

        Deliberately an identity check on the concrete type rather than
        isinstance, so subclasses are excluded.
        """
        layer_cls = type(source_layer)
        return layer_cls is AscendRowParallelLinear
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    """LoRA adapter targeting Ascend's vocab-parallel embedding layer."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only for an exact AscendVocabParallelEmbedding.

        Uses type identity instead of isinstance so subclasses never match.
        """
        layer_cls = type(source_layer)
        return layer_cls is AscendVocabParallelEmbedding
def refresh_all_lora_classes():
    """Register the Ascend LoRA layer wrappers with vLLM.

    vLLM selects replacement LoRA layers from
    ``vllm.lora.utils._all_lora_classes``; adding the Ascend-specific
    subclasses here lets vLLM's stock layer-resolution logic discover them
    without any patching.  Re-running this is harmless: the registry is a
    set, so repeated registration is a no-op.
    """
    ascend_lora_classes = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
    )
    vllm.lora.utils._all_lora_classes.update(ascend_lora_classes)