### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| `vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| `vllm_ascend/kv_offload/cpu_npu.py` |
| `vllm_ascend/kv_offload/npu.py` |
| `vllm_ascend/lora/lora_ops.py` |
| `vllm_ascend/lora/punica_npu.py` |
| `vllm_ascend/lora/utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
import vllm
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
from vllm.config import LoRAConfig
|
|
from vllm.lora.layers import (
|
|
ColumnParallelLinearWithLoRA,
|
|
MergedColumnParallelLinearWithLoRA,
|
|
MergedQKVParallelLinearWithLoRA,
|
|
QKVParallelLinearWithLoRA,
|
|
RowParallelLinearWithLoRA,
|
|
VocabParallelEmbeddingWithLoRA,
|
|
)
|
|
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
|
|
|
|
from vllm_ascend.ops.linear import (
|
|
AscendColumnParallelLinear,
|
|
AscendMergedColumnParallelLinear,
|
|
AscendQKVParallelLinear,
|
|
AscendRowParallelLinear,
|
|
)
|
|
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
|
|
|
|
|
|
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA layer variant selected for ``AscendColumnParallelLinear`` modules."""

    # NOTE(review): unlike the QKV variants in this file, this check is not
    # wrapped in `_not_fully_sharded_can_replace` -- confirm that is intended.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        Only the exact type of ``source_layer`` is inspected; the remaining
        arguments are accepted but not read here.

        Returns:
            True when ``source_layer`` is exactly an
            ``AscendColumnParallelLinear`` (subclasses do not match).
        """
        layer_type = type(source_layer)
        return layer_type is AscendColumnParallelLinear
|
|
|
|
|
|
class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """LoRA layer variant selected for ``AscendMergedColumnParallelLinear`` modules."""

    # NOTE(review): unlike the QKV variants in this file, this check is not
    # wrapped in `_not_fully_sharded_can_replace` -- confirm that is intended.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        Only the exact type of ``source_layer`` is inspected; the remaining
        arguments are accepted but not read here.

        Returns:
            True when ``source_layer`` is exactly an
            ``AscendMergedColumnParallelLinear`` (subclasses do not match).
        """
        layer_type = type(source_layer)
        return layer_type is AscendMergedColumnParallelLinear
|
|
|
|
|
|
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
    """LoRA layer variant selected for ``AscendRowParallelLinear`` modules."""

    # NOTE(review): unlike the QKV variants in this file, this check is not
    # wrapped in `_not_fully_sharded_can_replace` -- confirm that is intended.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        Only the exact type of ``source_layer`` is inspected; the remaining
        arguments are accepted but not read here.

        Returns:
            True when ``source_layer`` is exactly an
            ``AscendRowParallelLinear`` (subclasses do not match).
        """
        layer_type = type(source_layer)
        return layer_type is AscendRowParallelLinear
|
|
|
|
|
|
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    """LoRA layer variant selected for ``AscendVocabParallelEmbedding`` modules."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        Only the exact type of ``source_layer`` is inspected; the remaining
        arguments are accepted but not read here.

        Returns:
            True when ``source_layer`` is exactly an
            ``AscendVocabParallelEmbedding`` (subclasses do not match).
        """
        layer_type = type(source_layer)
        return layer_type is AscendVocabParallelEmbedding
|
|
|
|
|
|
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
    """LoRA layer variant for ``AscendQKVParallelLinear`` with one packed module."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        The result of this body is further processed by the
        ``_not_fully_sharded_can_replace`` decorator, whose logic lives in
        ``vllm.lora.layers.utils`` and is not visible here.

        Returns:
            True when ``source_layer`` is exactly an
            ``AscendQKVParallelLinear`` and ``packed_modules_list`` holds
            exactly one entry.
        """
        # Type gate first, so the length is only inspected for QKV layers.
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 1
|
|
|
|
|
|
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
    """LoRA layer variant for ``AscendQKVParallelLinear`` with three packed modules."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Report whether ``source_layer`` should be wrapped by this class.

        The result of this body is further processed by the
        ``_not_fully_sharded_can_replace`` decorator, whose logic lives in
        ``vllm.lora.layers.utils`` and is not visible here.

        Returns:
            True when ``source_layer`` is exactly an
            ``AscendQKVParallelLinear`` and ``packed_modules_list`` holds
            exactly three entries.
        """
        # Type gate first, so the length is only inspected for QKV layers.
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
|
|
|
|
|
|
def refresh_all_lora_classes():
    """Add the Ascend LoRA layer classes to ``vllm.lora.utils._all_lora_classes``.

    Each class is inserted with ``.add``, in the same order as listed below.
    """
    ascend_lora_classes = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
        AscendQKVParallelLinearWithLoRA,
        AscendMergedQKVParallelLinearWithLoRA,
    )
    for lora_cls in ascend_lora_classes:
        vllm.lora.utils._all_lora_classes.add(lora_cls)
|