### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
|
`.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py`
|
|
`vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py`
|
| `
.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py`
|
| `
.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py`
|
| `
vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py`
|
| `
.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py`
|
| `
.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py`
|
| ` vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py`
|
| ` vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `
vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| ` vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| ` vllm_ascend/kv_offload/cpu_npu.py` |
| ` vllm_ascend/kv_offload/npu.py` |
| ` vllm_ascend/lora/lora_ops.py` |
| ` vllm_ascend/lora/punica_npu.py` |
| ` vllm_ascend/lora/utils.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -16,11 +16,13 @@
|
||||
import torch
|
||||
|
||||
|
||||
def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
):
|
||||
return torch.ops._C_ascend.bgmv_shrink(
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
@@ -30,11 +32,13 @@ def bgmv_shrink(inputs: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
def bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
return torch.ops._C_ascend.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
@@ -45,16 +49,18 @@ def bgmv_expand(inputs: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
return torch.ops._C_ascend.bgmv_expand(
|
||||
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset, slice_size
|
||||
)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
@@ -69,21 +75,23 @@ def sgmv_shrink(
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, scaling)
|
||||
return torch.ops._C_ascend.sgmv_shrink(
|
||||
inputs, lora_a_weights, lora_indices_tensor, seq_len_tensor, output_tensor, scaling
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
def sgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False,
|
||||
):
|
||||
return torch.ops._C_ascend.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
@@ -95,19 +103,20 @@ def sgmv_expand(inputs: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, slice_offset,
|
||||
slice_size)
|
||||
def sgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False,
|
||||
):
|
||||
return torch.ops._C_ascend.sgmv_expand(
|
||||
inputs, lora_b_weights, lora_indices_tensor, seq_len_tensor, output_tensor, slice_offset, slice_size
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
@@ -18,26 +18,30 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_batched_tokens: int, max_batches: int,
|
||||
device: Union[torch.device, str], **kwargs):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
||||
device)
|
||||
def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||
refresh_all_lora_classes()
|
||||
self.lora_config = kwargs.get("lora_config")
|
||||
if get_ascend_device_type() == AscendDeviceType._310P or (
|
||||
self.lora_config is not None
|
||||
and self.lora_config.max_lora_rank >= 128):
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand,
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink)
|
||||
self.lora_config is not None and self.lora_config.max_lora_rank >= 128
|
||||
):
|
||||
from vllm.lora.ops.torch_ops import (
|
||||
bgmv_expand,
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink,
|
||||
)
|
||||
else:
|
||||
from vllm_ascend.lora.lora_ops import (bgmv_expand,
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink)
|
||||
from vllm_ascend.lora.lora_ops import (
|
||||
bgmv_expand,
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink,
|
||||
)
|
||||
self.bgmv_expand = bgmv_expand
|
||||
self.bgmv_expand_slice = bgmv_expand_slice
|
||||
self.bgmv_shrink = bgmv_shrink
|
||||
@@ -52,7 +56,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
w_t_all: torch.Tensor,
|
||||
scale: float,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
# No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
self.sgmv_shrink(
|
||||
@@ -79,7 +83,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
w_t_all: torch.Tensor,
|
||||
add_inputs: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
# No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
self.sgmv_expand(
|
||||
@@ -108,7 +112,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
#No LoRA request, so return directly
|
||||
# No LoRA request, so return directly
|
||||
if self.no_lora:
|
||||
return
|
||||
self.sgmv_expand_slice(
|
||||
@@ -130,8 +134,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
y_slice_size: int,
|
||||
add_inputs: bool,
|
||||
):
|
||||
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices,
|
||||
y_offset, y_slice_size, add_inputs)
|
||||
self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs)
|
||||
|
||||
def _apply_expand(
|
||||
self,
|
||||
@@ -148,13 +151,10 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
GEMM of lora'b.
|
||||
"""
|
||||
|
||||
expand_slice_fun: Callable = (self._expand_slice_prefill
|
||||
if self.is_prefill else
|
||||
self._expand_slice_decode)
|
||||
expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
|
||||
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
|
||||
|
||||
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
|
||||
w_t_all: torch.Tensor, scale: float):
|
||||
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float):
|
||||
"""
|
||||
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||
GEMM of lora'a.
|
||||
@@ -165,14 +165,18 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
"""
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
shrink_fun: Callable = (self._shrink_prefill
|
||||
if self.is_prefill else self._shrink_decode)
|
||||
shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode
|
||||
shrink_fun(y, x, w_t_all, scale)
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
scale: float, **kwargs):
|
||||
def add_shrink(
|
||||
self,
|
||||
y: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Performs GEMM for multiple slices of lora_a.
|
||||
When `is_prefill is` true, it indicates that it is currently the
|
||||
@@ -194,18 +198,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# TODO fuse these kernels
|
||||
for slice_idx in range(len(lora_a_stacked)):
|
||||
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
|
||||
scale)
|
||||
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
|
||||
|
||||
def add_expand(self,
|
||||
y: torch.Tensor,
|
||||
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||
output_slices: Tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs) -> None:
|
||||
def add_expand(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: tuple[torch.Tensor, ...] | torch.Tensor,
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
lora_bias_stacked: tuple[torch.Tensor, ...] | None,
|
||||
output_slices: tuple[int, ...],
|
||||
offset_start: int = 0,
|
||||
add_inputs=True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||
|
||||
@@ -229,8 +234,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
y = y.view(-1, y.shape[-1])
|
||||
offset_left = offset_start
|
||||
if lora_bias_stacked is not None:
|
||||
self._apply_bias(self.token_lora_indices, y, output_slices,
|
||||
lora_bias_stacked)
|
||||
self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
|
||||
for slice_idx in range(len(lora_b_stacked)):
|
||||
self._apply_expand(
|
||||
y,
|
||||
@@ -243,12 +247,9 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
offset_left += output_slices[slice_idx]
|
||||
y = y.view_as(y_org)
|
||||
|
||||
def add_lora_embedding(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
**kwargs) -> None:
|
||||
def add_lora_embedding(
|
||||
self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||
|
||||
@@ -263,21 +264,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
"""
|
||||
|
||||
# Embedding layer only need expand op
|
||||
expand_fun: Callable = (self._expand_prefill
|
||||
if self.is_prefill else self._expand_decode)
|
||||
expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode
|
||||
x = x.to(torch.float32)
|
||||
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||
|
||||
def add_lora_linear(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: Tuple[int, ...],
|
||||
*,
|
||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||
**kwargs) -> None:
|
||||
def add_lora_linear(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
scale: float,
|
||||
output_slices: tuple[int, ...],
|
||||
*,
|
||||
buffer: tuple[torch.Tensor, ...] | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
|
||||
@@ -308,27 +310,22 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
# We set the buffer to be float32 by default, consistent with the
|
||||
# triton op
|
||||
buffer = tuple(
|
||||
torch.zeros(
|
||||
(x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
for _ in range(len(output_slices)))
|
||||
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
|
||||
)
|
||||
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
|
||||
self.add_expand(y,
|
||||
buffer,
|
||||
lora_b_stacked,
|
||||
None,
|
||||
output_slices,
|
||||
add_inputs=True,
|
||||
**kwargs)
|
||||
self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
|
||||
|
||||
def add_lora_logits(self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: Optional[torch.Tensor] = None,
|
||||
**kwargs) -> None:
|
||||
def add_lora_logits(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
scale,
|
||||
*,
|
||||
buffer: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||
|
||||
@@ -350,9 +347,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
r = lora_b_stacked.size(-1)
|
||||
|
||||
if buffer is None:
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
|
||||
indices = self.sampler_indices
|
||||
|
||||
|
||||
@@ -1,91 +1,75 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
VocabParallelEmbeddingWithLoRA)
|
||||
from vllm.lora.layers import (
|
||||
ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
VocabParallelEmbeddingWithLoRA,
|
||||
)
|
||||
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
from vllm_ascend.ops.linear import (
|
||||
AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear,
|
||||
)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendColumnParallelLinear
|
||||
|
||||
|
||||
class AscendMergedColumnParallelLinearWithLoRA(
|
||||
MergedColumnParallelLinearWithLoRA):
|
||||
|
||||
class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendMergedColumnParallelLinear
|
||||
|
||||
|
||||
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendRowParallelLinear
|
||||
|
||||
|
||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
return type(source_layer) is AscendQKVParallelLinear and len(
|
||||
packed_modules_list) == 1
|
||||
|
||||
|
||||
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
@@ -93,18 +77,28 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return (type(source_layer) is AscendQKVParallelLinear
|
||||
and len(packed_modules_list) == 3)
|
||||
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
|
||||
|
||||
|
||||
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
|
||||
|
||||
|
||||
def refresh_all_lora_classes():
|
||||
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(
|
||||
AscendMergedColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(
|
||||
AscendMergedQKVParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
|
||||
|
||||
Reference in New Issue
Block a user