"""LoRA layer wrappers for the Ascend-specific linear and embedding layers.

Each wrapper only overrides ``can_replace_layer`` so that vLLM's LoRA layer
replacement recognizes the corresponding ``Ascend*`` layer types.
"""

import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
    VocabParallelEmbeddingWithLoRA,
)
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.utils import (
    _fully_sharded_can_replace,
    _not_fully_sharded_can_replace,
)

from vllm_ascend.ops.linear import (
    AscendColumnParallelLinear,
    AscendMergedColumnParallelLinear,
    AscendQKVParallelLinear,
    AscendReplicatedLinear,
    AscendRowParallelLinear,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding


class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendColumnParallelLinear


class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendMergedColumnParallelLinear


class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendRowParallelLinear


class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        return type(source_layer) is AscendVocabParallelEmbedding


class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        # An unpacked QKV projection shows up as a single packed module.
        return (type(source_layer) is AscendQKVParallelLinear
                and len(packed_modules_list) == 1)


class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        # A merged QKV projection packs the q, k and v sub-modules.
        return (type(source_layer) is AscendQKVParallelLinear
                and len(packed_modules_list) == 3)


class AscendReplicatedLinearWithLoRA(ReplicatedLinearWithLoRA):

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is AscendReplicatedLinear


class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA):

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is AscendColumnParallelLinear


class AscendMergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithShardedLoRA):

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is AscendMergedColumnParallelLinear


class AscendMergedQKVParallelLinearWithShardedLoRA(
        MergedQKVParallelLinearWithShardedLoRA):

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return (type(source_layer) is AscendQKVParallelLinear
                and len(packed_modules_list) == 3)


class AscendQKVParallelLinearWithShardedLoRA(QKVParallelLinearWithShardedLoRA):

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return (type(source_layer) is AscendQKVParallelLinear
                and len(packed_modules_list) == 1)


class AscendRowParallelLinearWithShardedLoRA(RowParallelLinearWithShardedLoRA):

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is AscendRowParallelLinear


def refresh_all_lora_classes():
    """Register the Ascend LoRA wrappers with vLLM's LoRA class registry."""
    vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithShardedLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithShardedLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA)
    vllm.lora.utils._all_lora_classes.add(AscendReplicatedLinearWithLoRA)
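# Usage sketch (illustrative only; the import path and call site below are
# assumptions, not part of this module): ``refresh_all_lora_classes`` is meant
# to be called once during plugin/platform initialization so that vLLM's LoRA
# layer replacement considers the Ascend wrappers defined above.
#
#     from vllm_ascend.lora.utils import refresh_all_lora_classes  # hypothetical path
#
#     refresh_all_lora_classes()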