init v0.11.0rc0
This commit is contained in:
@@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
return torch.ops._C.bgmv_shrink(
|
||||
return torch.ops._C_ascend.bgmv_shrink(
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
lora_indices_tensor,
|
||||
@@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(
|
||||
return torch.ops._C_ascend.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
@@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
@@ -69,9 +69,9 @@ def sgmv_shrink(
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, scaling)
|
||||
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, scaling)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
@@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C.sgmv_expand(
|
||||
return torch.ops._C_ascend.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
@@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, slice_offset, slice_size)
|
||||
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, seq_len_tensor,
|
||||
output_tensor, slice_offset,
|
||||
slice_size)
|
||||
@@ -11,12 +11,14 @@ if is_310p():
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
else:
|
||||
from vllm_ascend.lora.punica_wrapper.lora_ops import (
|
||||
bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
|
||||
from vllm_ascend.lora.utils import refresh_all_lora_classes
|
||||
|
||||
|
||||
# The platforms that are compatible with the PyTorch-native implementation can
|
||||
# inherit this class
|
||||
@@ -31,6 +33,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
device: Union[torch.device, str], **kwargs):
|
||||
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
|
||||
device)
|
||||
refresh_all_lora_classes()
|
||||
|
||||
def _shrink_prefill(
|
||||
self,
|
||||
@@ -338,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
y_org = y
|
||||
y = y.view(-1, y.shape[-1])
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
if lora_a_stacked.dim() == 2:
|
||||
lora_a_stacked = lora_a_stacked.unsqueeze(0)
|
||||
if lora_b_stacked.dim() == 2:
|
||||
lora_b_stacked = lora_b_stacked.unsqueeze(0)
|
||||
|
||||
r = lora_a_stacked.size(-1)
|
||||
r = lora_b_stacked.size(-1)
|
||||
|
||||
if buffer is None:
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
@@ -352,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
device=x.device)
|
||||
|
||||
indices = self.sampler_indices
|
||||
if indices.max() >= lora_a_stacked.size(0):
|
||||
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
|
||||
|
||||
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
|
||||
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
|
||||
|
||||
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
|
||||
bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
|
||||
|
||||
y = y.view_as(y_org)
|
||||
110
vllm_ascend/lora/utils.py
Normal file
110
vllm_ascend/lora/utils.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from typing import Optional
|
||||
|
||||
import vllm
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
VocabParallelEmbeddingWithLoRA)
|
||||
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import \
|
||||
AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendColumnParallelLinear
|
||||
|
||||
|
||||
class AscendMergedColumnParallelLinearWithLoRA(
|
||||
MergedColumnParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendMergedColumnParallelLinear
|
||||
|
||||
|
||||
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendRowParallelLinear
|
||||
|
||||
|
||||
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is AscendVocabParallelEmbedding
|
||||
|
||||
|
||||
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
return type(source_layer) is AscendQKVParallelLinear and len(
|
||||
packed_modules_list) == 1
|
||||
|
||||
|
||||
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return (type(source_layer) is AscendQKVParallelLinear
|
||||
and len(packed_modules_list) == 3)
|
||||
|
||||
|
||||
def refresh_all_lora_classes():
|
||||
vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(
|
||||
AscendMergedColumnParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
|
||||
vllm.lora.utils._all_lora_classes.add(
|
||||
AscendMergedQKVParallelLinearWithLoRA)
|
||||
Reference in New Issue
Block a user