2025-09-07 10:31:32 +08:00
|
|
|
from typing import Optional
|
|
|
|
|
|
2025-09-08 21:42:12 +08:00
|
|
|
import vllm
|
2025-09-07 10:31:32 +08:00
|
|
|
from torch import nn
|
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
from vllm.config import LoRAConfig
|
|
|
|
|
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
|
|
|
|
MergedColumnParallelLinearWithLoRA,
|
2025-09-08 21:42:12 +08:00
|
|
|
RowParallelLinearWithLoRA,
|
|
|
|
|
VocabParallelEmbeddingWithLoRA)
|
2025-09-07 10:31:32 +08:00
|
|
|
|
|
|
|
|
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
|
|
|
|
AscendMergedColumnParallelLinear,
|
|
|
|
|
AscendRowParallelLinear)
|
2025-09-08 21:42:12 +08:00
|
|
|
from vllm_ascend.ops.vocab_parallel_embedding import \
|
|
|
|
|
AscendVocabParallelEmbedding
|
2025-09-07 10:31:32 +08:00
|
|
|
|
|
|
|
|
|
2025-09-08 21:42:12 +08:00
|
|
|
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA adapter wrapper targeting Ascend's column-parallel linear layer.

    Inherits all LoRA behavior from ``ColumnParallelLinearWithLoRA`` and only
    overrides the replacement predicate so vLLM's LoRA machinery swaps in this
    class for the Ascend-specific base layer.
    """

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True iff *source_layer* is exactly AscendColumnParallelLinear.

        An identity check on the concrete type is used deliberately so that
        subclasses of the Ascend layer are not matched.
        """
        layer_type = type(source_layer)
        return layer_type is AscendColumnParallelLinear
|
2025-09-07 10:31:32 +08:00
|
|
|
|
|
|
|
|
|
2025-09-08 21:42:12 +08:00
|
|
|
class AscendMergedColumnParallelLinearWithLoRA(
        MergedColumnParallelLinearWithLoRA):
    """LoRA adapter wrapper targeting Ascend's merged column-parallel linear.

    All LoRA logic comes from ``MergedColumnParallelLinearWithLoRA``; only the
    replacement predicate is specialized for the Ascend base layer.
    """

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True iff *source_layer* is exactly AscendMergedColumnParallelLinear.

        Exact-type identity (not ``isinstance``) keeps subclasses from being
        replaced unintentionally.
        """
        layer_type = type(source_layer)
        return layer_type is AscendMergedColumnParallelLinear
|
2025-09-07 10:31:32 +08:00
|
|
|
|
|
|
|
|
|
2025-09-08 21:42:12 +08:00
|
|
|
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
    """LoRA adapter wrapper targeting Ascend's row-parallel linear layer.

    Reuses ``RowParallelLinearWithLoRA`` wholesale; only the replacement
    predicate is specialized for the Ascend base layer.
    """

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True iff *source_layer* is exactly AscendRowParallelLinear.

        Exact-type identity (not ``isinstance``) keeps subclasses from being
        replaced unintentionally.
        """
        layer_type = type(source_layer)
        return layer_type is AscendRowParallelLinear
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    """LoRA adapter wrapper targeting Ascend's vocab-parallel embedding layer.

    Inherits all LoRA behavior from ``VocabParallelEmbeddingWithLoRA``; only
    the replacement predicate is specialized for the Ascend base layer.
    """

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True iff *source_layer* is exactly AscendVocabParallelEmbedding.

        Exact-type identity (not ``isinstance``) keeps subclasses from being
        replaced unintentionally.
        """
        layer_type = type(source_layer)
        return layer_type is AscendVocabParallelEmbedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def refresh_all_lora_classes():
    """Register the Ascend LoRA layer subclasses with vLLM's LoRA dispatcher.

    vLLM selects LoRA wrappers by scanning the (private) module-level set
    ``vllm.lora.utils._all_lora_classes``; adding our subclasses there makes
    the Ascend-specific layers eligible for replacement.  ``set.add`` is
    idempotent, so calling this more than once is harmless.
    """
    ascend_lora_classes = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
    )
    for lora_cls in ascend_lora_classes:
        vllm.lora.utils._all_lora_classes.add(lora_cls)
|