[Bugfix][LoRA] Fix the issue when enabling LoRA + TP + fully_sharded_loras (#6650)
### What this PR does / why we need it?
Fix issue #6143.
### Does this PR introduce _any_ user-facing change?
Allows starting the server with "--enable-lora", "--fully-sharded-loras", and "--tensor_parallel_size 2" together.
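For reference, the equivalent offline setup looks roughly like the sketch below; the model name and adapter path are placeholders rather than anything from this PR, and the keyword arguments mirror the server flags above.

```python
# Rough sketch of the configuration this fix enables:
# LoRA + fully sharded LoRA layers + tensor parallelism across 2 cards.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # placeholder base model
    enable_lora=True,
    fully_sharded_loras=True,
    tensor_parallel_size=2,
    max_loras=1,
)

outputs = llm.generate(
    ["Give me a short introduction to NPUs."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # Placeholder adapter name/path; any compatible LoRA adapter works.
    lora_request=LoRARequest("demo-lora", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)
```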
### How was this patch tested?
pytest -sv tests/e2e/multicard/2-cards/test_llama32_lora_tp2.py
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
---------
Signed-off-by: paulyu12 <507435917@qq.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
vllm_ascend/lora/punica_npu.py | 11 changes | Normal file → Executable file
@@ -205,7 +205,6 @@ class PunicaWrapperNPU(PunicaWrapperBase):
         y: torch.Tensor,
         x: tuple[torch.Tensor, ...] | torch.Tensor,
         lora_b_stacked: tuple[torch.Tensor, ...],
-        lora_bias_stacked: tuple[torch.Tensor, ...] | None,
         output_slices: tuple[int, ...],
         offset_start: int = 0,
         add_inputs=True,
@@ -217,24 +216,20 @@ class PunicaWrapperNPU(PunicaWrapperBase):
         Semantics:
             for i in range(len(lora_b_stacked)):
                 slice = output_slices[i]
-                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
-                    lora_bias_stacked[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
                 offset += slice

         Args:
             y (torch.Tensor): Output tensor.
             x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
             lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
-            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
-                bias's weight
             output_slices (Tuple[int, ...]): Every slice's size
             offset_start (int): The starting position of y, defaults to 0
             add_inputs (bool): Defaults to True.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
         offset_left = offset_start
-        if lora_bias_stacked is not None:
-            self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked)
         for slice_idx in range(len(lora_b_stacked)):
             self._apply_expand(
                 y,
@@ -313,7 +308,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
                 torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices))
             )
         self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
-        self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs)
+        self.add_expand(y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs)

     def add_lora_logits(
         self,
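For context, the Semantics block in the docstring above amounts to the slice-wise expand sketched below (a standalone reference with illustrative shapes, not the NPU kernel path that `_apply_expand` actually uses):

```python
import torch

# Reference implementation of the add_expand semantics documented above,
# after the bias term is removed. Shapes are illustrative only.
def add_expand_reference(y, xs, lora_b_stacked, output_slices, offset_start=0):
    offset = offset_start
    for i in range(len(lora_b_stacked)):
        size = output_slices[i]
        y[:, offset:offset + size] += xs[i] @ lora_b_stacked[i]
        offset += size
    return y

y = torch.zeros(4, 6)                            # (num_tokens, sum(output_slices))
xs = (torch.randn(4, 2), torch.randn(4, 2))      # per-slice shrink outputs, rank 2
lora_b = (torch.randn(2, 3), torch.randn(2, 3))  # per-slice lora_b weights
add_expand_reference(y, xs, lora_b, output_slices=(3, 3))
```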
vllm_ascend/lora/utils.py | 80 changes | Normal file → Executable file
@@ -4,13 +4,18 @@ from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
 from vllm.lora.layers import (
     ColumnParallelLinearWithLoRA,
+    ColumnParallelLinearWithShardedLoRA,
     MergedColumnParallelLinearWithLoRA,
+    MergedColumnParallelLinearWithShardedLoRA,
     MergedQKVParallelLinearWithLoRA,
+    MergedQKVParallelLinearWithShardedLoRA,
     QKVParallelLinearWithLoRA,
+    QKVParallelLinearWithShardedLoRA,
     RowParallelLinearWithLoRA,
+    RowParallelLinearWithShardedLoRA,
     VocabParallelEmbeddingWithLoRA,
 )
-from vllm.lora.layers.utils import _not_fully_sharded_can_replace
+from vllm.lora.layers.utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace

 from vllm_ascend.ops.linear import (
     AscendColumnParallelLinear,
@@ -23,6 +28,7 @@ from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbeddingWithLoRA

 class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
@@ -35,6 +41,7 @@ class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):

 class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
@@ -47,6 +54,7 @@ class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):

 class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
@@ -95,6 +103,71 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
         return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3


+class AscendColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithShardedLoRA):
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: PretrainedConfig | None = None,
+    ) -> bool:
+        return type(source_layer) is AscendColumnParallelLinear
+
+
+class AscendMergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithShardedLoRA):
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: PretrainedConfig | None = None,
+    ) -> bool:
+        return type(source_layer) is AscendMergedColumnParallelLinear
+
+
+class AscendMergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithShardedLoRA):
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: PretrainedConfig | None = None,
+    ) -> bool:
+        return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3
+
+
+class AscendQKVParallelLinearWithShardedLoRA(QKVParallelLinearWithShardedLoRA):
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: PretrainedConfig | None = None,
+    ) -> bool:
+        return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1
+
+
+class AscendRowParallelLinearWithShardedLoRA(RowParallelLinearWithShardedLoRA):
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: PretrainedConfig | None = None,
+    ) -> bool:
+        return type(source_layer) is AscendRowParallelLinear
+
+
 def refresh_all_lora_classes():
     vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA)
     vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA)
@@ -102,3 +175,8 @@ def refresh_all_lora_classes():
     vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
     vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA)
     vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
+    vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithShardedLoRA)
+    vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithShardedLoRA)
+    vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithShardedLoRA)
+    vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithShardedLoRA)
+    vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithShardedLoRA)
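For context, these registrations matter because vLLM's LoRA layer replacement walks `_all_lora_classes` and picks a class whose `can_replace_layer` check passes; the `@_fully_sharded_can_replace` and `@_not_fully_sharded_can_replace` decorators gate each candidate on `lora_config.fully_sharded_loras`. A simplified sketch of that selection (not vLLM's actual implementation):

```python
# Simplified illustration (not vLLM's actual code) of how the registry that
# refresh_all_lora_classes() extends is consulted during layer replacement.
def pick_lora_class(source_layer, lora_config, packed_modules_list,
                    model_config, all_lora_classes):
    for lora_cls in all_lora_classes:
        # The _fully_sharded_can_replace / _not_fully_sharded_can_replace
        # decorators short-circuit this check based on
        # lora_config.fully_sharded_loras, so the sharded Ascend variants
        # added in this PR only match when --fully-sharded-loras is set.
        if lora_cls.can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        ):
            return lora_cls
    return None
```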