[Feature] totaly support multi-lora support,latest xspeedgate needed (#133)
Co-authored-by: wanghao <wanghao@example.com>
This commit is contained in:
0
vllm_kunlun/lora/__init__.py
Normal file
0
vllm_kunlun/lora/__init__.py
Normal file
0
vllm_kunlun/lora/ops/__init__.py
Normal file
0
vllm_kunlun/lora/ops/__init__.py
Normal file
16
vllm_kunlun/lora/ops/kunlun_ops/__init__.py
Normal file
16
vllm_kunlun/lora/ops/kunlun_ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project"""
|
||||||
|
|
||||||
|
from vllm_kunlun.lora.ops.kunlun_ops.lora_ops import (bgmv_expand,bgmv_expand_slice, bgmv_shrink,
|
||||||
|
sgmv_expand, sgmv_expand_slice,
|
||||||
|
sgmv_shrink)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"bgmv_expand",
|
||||||
|
"bgmv_expand_slice",
|
||||||
|
"bgmv_shrink",
|
||||||
|
"sgmv_expand",
|
||||||
|
"sgmv_expand_slice",
|
||||||
|
"sgmv_shrink"
|
||||||
|
]
|
||||||
133
vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py
Normal file
133
vllm_kunlun/lora/ops/kunlun_ops/lora_ops.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""kunlun_ops for lora"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import xspeedgate_ops
|
||||||
|
import time
|
||||||
|
from torch._C import dtype
|
||||||
|
import os
|
||||||
|
from torch._dynamo import disable
|
||||||
|
|
||||||
|
|
||||||
|
def sgmv_shrink(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
scaling: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
sgmv_shrink
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def sgmv_expand(inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
add_inputs: bool = False):
|
||||||
|
"""
|
||||||
|
sgmv_expand
|
||||||
|
"""
|
||||||
|
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = False):
|
||||||
|
|
||||||
|
"""
|
||||||
|
sgmv_expand_slice
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def bgmv_shrink(
|
||||||
|
inputs: torch.Tensor, # [m, hidden_dim]
|
||||||
|
lora_a_weights: torch.Tensor, # [n, 1, r, hidden_dim]
|
||||||
|
output_tensor: torch.Tensor, # [m, r]
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor, # [m]
|
||||||
|
scaling: float = 1.0
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
bgmv_shrink
|
||||||
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling)
|
||||||
|
|
||||||
|
|
||||||
|
def bgmv_expand(inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
add_inputs: bool = True):
|
||||||
|
""""
|
||||||
|
bgmv_expand
|
||||||
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0)
|
||||||
|
# @my_wrapper
|
||||||
|
|
||||||
|
def bgmv_expand_slice(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
bgmv_expand_slice
|
||||||
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset)
|
||||||
0
vllm_kunlun/lora/punica_wrapper/__init__.py
Normal file
0
vllm_kunlun/lora/punica_wrapper/__init__.py
Normal file
548
vllm_kunlun/lora/punica_wrapper/punica_kunlun.py
Normal file
548
vllm_kunlun/lora/punica_wrapper/punica_kunlun.py
Normal file
@@ -0,0 +1,548 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
||||||
|
# Author: Wang Hao
|
||||||
|
# Email: wanghao129@baidu.com
|
||||||
|
# This file is a part of the vllm-kunlun project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Based on:
|
||||||
|
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||||
|
Punica: Multi-Tenant LoRA Serving.
|
||||||
|
https://arxiv.org/abs/2310.18547
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Optional, Union, final
|
||||||
|
|
||||||
|
import torch
|
||||||
|
# Disable torchdynamo for all functions in this file
|
||||||
|
torch._dynamo.config.disable = True
|
||||||
|
|
||||||
|
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
from vllm_kunlun.lora.ops.kunlun_ops import (
|
||||||
|
bgmv_expand,
|
||||||
|
bgmv_expand_slice,
|
||||||
|
bgmv_shrink,
|
||||||
|
sgmv_expand,
|
||||||
|
sgmv_expand_slice,
|
||||||
|
sgmv_shrink,
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||||
|
|
||||||
|
|
||||||
|
# The platforms that are compatible with the PyTorch-native implementation can
|
||||||
|
# inherit this class
|
||||||
|
class PunicaWrapperKunlun(PunicaWrapperBase):
|
||||||
|
"""
|
||||||
|
PunicaWrapperKunlun with moe_fc
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_num_batched_tokens: int,
|
||||||
|
max_batches: int,
|
||||||
|
device: Union[torch.device, str],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
|
||||||
|
|
||||||
|
def _shrink_prefill(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
):
|
||||||
|
|
||||||
|
expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
|
||||||
|
|
||||||
|
sgmv_shrink(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
expert_m,
|
||||||
|
*self.prefill_metadata,
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _shrink_decode(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
):
|
||||||
|
|
||||||
|
expert_m = torch.zeros(9, dtype=torch.int32, device=x.device)
|
||||||
|
bgmv_shrink(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
expert_m,
|
||||||
|
self.token_lora_indices,
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_prefill(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
sgmv_expand(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
*self.prefill_metadata,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_decode(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
bgmv_expand(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
self.token_lora_indices,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_slice_prefill(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
y_offset: int,
|
||||||
|
y_slice_size: int,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
|
||||||
|
|
||||||
|
sgmv_expand_slice(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
normed_scale,
|
||||||
|
*self.prefill_metadata,
|
||||||
|
y_offset,
|
||||||
|
y_slice_size,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _expand_slice_decode(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
y_offset: int,
|
||||||
|
y_slice_size: int,
|
||||||
|
add_inputs: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
normed_scale = torch.ones([y.size(0), 1], dtype=torch.float32, device=x.device)
|
||||||
|
|
||||||
|
bgmv_expand_slice(
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
y,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
normed_scale,
|
||||||
|
self.token_lora_indices,
|
||||||
|
y_offset,
|
||||||
|
y_slice_size,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_expand(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
y_offset: int,
|
||||||
|
y_slice_size: int,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||||
|
computation, which is suitable for the
|
||||||
|
GEMM of lora'b.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expand_slice_fun: Callable = (
|
||||||
|
self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
|
||||||
|
)
|
||||||
|
expand_slice_fun(
|
||||||
|
y,
|
||||||
|
x,
|
||||||
|
w_t_all,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
y_offset,
|
||||||
|
y_slice_size,
|
||||||
|
add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_shrink(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w_t_all: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform the ` y+=x@w_t_all` computation, which is suitable for the
|
||||||
|
GEMM of lora'a.
|
||||||
|
When `is_prefill is` true, it indicates that it is currently the
|
||||||
|
prefill stage, and the `_shrink_prefill` function should be called.
|
||||||
|
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||||
|
should be called.
|
||||||
|
"""
|
||||||
|
y_org = y
|
||||||
|
y = y.view(-1, y.shape[-1])
|
||||||
|
|
||||||
|
shrink_fun: Callable = (
|
||||||
|
self._shrink_prefill if self.is_prefill else self._shrink_decode
|
||||||
|
)
|
||||||
|
|
||||||
|
shrink_fun(
|
||||||
|
y, x, w_t_all, block_statistic, sorted_tokens_num_lod, moe_index, scale
|
||||||
|
)
|
||||||
|
|
||||||
|
y = y.view_as(y_org)
|
||||||
|
|
||||||
|
def add_shrink(
|
||||||
|
self,
|
||||||
|
y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Performs GEMM for multiple slices of lora_a.
|
||||||
|
When `is_prefill is` true, it indicates that it is currently the
|
||||||
|
prefill stage, and the `_shrink_prefill` function should be called.
|
||||||
|
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||||
|
should be called.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
for i in range(len(lora_a_stacked)):
|
||||||
|
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
|
||||||
|
scale (float): Scaling factor for the operation
|
||||||
|
"""
|
||||||
|
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
for slice_idx in range(len(lora_a_stacked)): # Each slice represents a layer
|
||||||
|
|
||||||
|
self._apply_shrink(
|
||||||
|
y[slice_idx],
|
||||||
|
x,
|
||||||
|
lora_a_stacked[slice_idx],
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_expand(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
||||||
|
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||||
|
output_slices: Tuple[int, ...],
|
||||||
|
offset_start: int = 0,
|
||||||
|
add_inputs=True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
for i in range(len(lora_b_stacked)):
|
||||||
|
slice = output_slices[i]
|
||||||
|
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||||
|
lora_bias_stacked[i]
|
||||||
|
offset += slice
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor.
|
||||||
|
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||||
|
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||||
|
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||||
|
bias's weight
|
||||||
|
output_slices (Tuple[int, ...]): Every slice's size
|
||||||
|
add_inputs (bool): Defaults to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
y_org = y
|
||||||
|
y = y.view(-1, y.shape[-1])
|
||||||
|
offset_left = offset_start
|
||||||
|
|
||||||
|
if lora_bias_stacked is not None:
|
||||||
|
self._apply_bias(
|
||||||
|
self.token_lora_indices, y, output_slices, lora_bias_stacked
|
||||||
|
)
|
||||||
|
|
||||||
|
for slice_idx in range(len(lora_b_stacked)):
|
||||||
|
self._apply_expand(
|
||||||
|
y,
|
||||||
|
x[slice_idx],
|
||||||
|
lora_b_stacked[slice_idx],
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
offset_left,
|
||||||
|
output_slices[slice_idx],
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
offset_left += output_slices[slice_idx]
|
||||||
|
|
||||||
|
y = y.view_as(y_org)
|
||||||
|
|
||||||
|
def add_lora_embedding(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_b_stacked: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
y += x @ lora_b_stacked
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor.
|
||||||
|
x (torch.Tensor): Input tensor.
|
||||||
|
lora_b_stacked (torch.Tensor): lora_b's weights.
|
||||||
|
add_inputs (bool): Default to True.
|
||||||
|
"""
|
||||||
|
|
||||||
|
expand_fun: Callable = (
|
||||||
|
self._expand_prefill if self.is_prefill else self._expand_decode
|
||||||
|
)
|
||||||
|
expand_fun(y, x, lora_b_stacked, add_inputs)
|
||||||
|
|
||||||
|
def add_lora_linear(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
lora_b_stacked: Tuple[torch.Tensor, ...],
|
||||||
|
lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
|
||||||
|
scale: float,
|
||||||
|
output_slices: Tuple[int, ...],
|
||||||
|
*,
|
||||||
|
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Applicable to linear-related lora.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
for i in range(len(lora_a_stacked)):
|
||||||
|
y[i] += (
|
||||||
|
x[i].unsqueeze(0)
|
||||||
|
@ lora_a_stacked[indices[i], layer_idx, :, :]
|
||||||
|
@ lora_b_stacked[indices[i], layer_idx, :, :]
|
||||||
|
* scale
|
||||||
|
).squeeze(0)+lora_bias_stacked[i]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor. Will be changed in-place.
|
||||||
|
x (torch.Tensor): Input tensor
|
||||||
|
lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
|
||||||
|
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
|
||||||
|
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
|
||||||
|
scale (float): Scaling factor.
|
||||||
|
output_slices (Tuple[int, ...]): Every slice's size.
|
||||||
|
buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.no_lora:
|
||||||
|
return
|
||||||
|
|
||||||
|
expert_num = 9
|
||||||
|
block_statistic = torch.zeros(
|
||||||
|
[12, expert_num], dtype=torch.int32, device=x.device
|
||||||
|
)
|
||||||
|
sorted_tokens_num_lod = torch.zeros(
|
||||||
|
expert_num + 1, dtype=torch.int32, device=x.device
|
||||||
|
)
|
||||||
|
token_nums = x.size(0)
|
||||||
|
moe_index = torch.zeros(token_nums, dtype=torch.int32, device=x.device)
|
||||||
|
|
||||||
|
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
|
||||||
|
if lora_bias_stacked is not None:
|
||||||
|
assert len(lora_bias_stacked) == len(output_slices)
|
||||||
|
y = self._apply_bias(
|
||||||
|
self.token_lora_indices, y, output_slices, lora_bias_stacked
|
||||||
|
)
|
||||||
|
|
||||||
|
if buffer is None:
|
||||||
|
r = lora_b_stacked[0].size(-1)
|
||||||
|
buffer = tuple(
|
||||||
|
torch.zeros((x.size(0), r), dtype=torch.float16, device=x.device)
|
||||||
|
for _ in range(len(output_slices))
|
||||||
|
)
|
||||||
|
# [tensor.squeeze_(1) for tensor in lora_a_stacked]
|
||||||
|
new_lora_a_stacked = tuple(lora_a.squeeze(1) for lora_a in lora_a_stacked)
|
||||||
|
self.add_shrink(
|
||||||
|
buffer,
|
||||||
|
x,
|
||||||
|
new_lora_a_stacked,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
scale,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# [tensor.unsqueeze_(1) for tensor in lora_a_stacked]
|
||||||
|
|
||||||
|
# [tensor.squeeze_(1) for tensor in lora_b_stacked]
|
||||||
|
new_lora_b_stacked = tuple(lora_b.squeeze(1) for lora_b in lora_b_stacked)
|
||||||
|
self.add_expand(
|
||||||
|
y,
|
||||||
|
buffer,
|
||||||
|
new_lora_b_stacked,
|
||||||
|
block_statistic,
|
||||||
|
sorted_tokens_num_lod,
|
||||||
|
moe_index,
|
||||||
|
None,
|
||||||
|
output_slices,
|
||||||
|
add_inputs=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
# [tensor.unsqueeze_(1) for tensor in lora_b_stacked]
|
||||||
|
|
||||||
|
def add_lora_logits(
|
||||||
|
self,
|
||||||
|
y: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
lora_a_stacked: torch.Tensor,
|
||||||
|
lora_b_stacked: torch.Tensor,
|
||||||
|
scale,
|
||||||
|
*,
|
||||||
|
buffer: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Applies lora specifically for LogitsProcessorWithLoRA.
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
buffer = (x @ lora_a_stacked) * scale
|
||||||
|
y += buffer @ lora_b_stacked
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Output tensor.
|
||||||
|
x (torch.Tensor): Input tensor.
|
||||||
|
lora_a_stacked (torch.Tensor): lora_a's weights.
|
||||||
|
lora_b_stacked (torch.Tensor):lora_b's weights.
|
||||||
|
scale (float): Scaling factor.
|
||||||
|
buffer (Optional[torch.Tensor]):Default to None.
|
||||||
|
"""
|
||||||
|
y_org = y
|
||||||
|
y = y.view(-1, y.shape[-1])
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
|
||||||
|
if lora_a_stacked.dim() == 2:
|
||||||
|
lora_a_stacked = lora_a_stacked.unsqueeze(0)
|
||||||
|
if lora_b_stacked.dim() == 2:
|
||||||
|
lora_b_stacked = lora_b_stacked.unsqueeze(0)
|
||||||
|
|
||||||
|
r = lora_a_stacked.size(-1)
|
||||||
|
|
||||||
|
if buffer is None:
|
||||||
|
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||||
|
|
||||||
|
indices = self.sampler_indices
|
||||||
|
if indices.max() >= lora_a_stacked.size(0):
|
||||||
|
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
|
||||||
|
|
||||||
|
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
|
||||||
|
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
|
||||||
|
|
||||||
|
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||||
|
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||||
|
|
||||||
|
y = y.view_as(y_org)
|
||||||
@@ -311,7 +311,10 @@ class KunlunPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_punica_wrapper(cls):
|
def get_punica_wrapper(cls):
|
||||||
return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
|
'''
|
||||||
|
kunlun wrapper
|
||||||
|
'''
|
||||||
|
return "vllm_kunlun.lora.punica_wrapper.punica_kunlun.PunicaWrapperKunlun"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||||
|
|||||||
Reference in New Issue
Block a user