Further optimize multi-LoRA inference; LoRA-enabled performance achieves 80%+ of non-LoRA performance (#190)
* optimize LoRA inference
  Signed-off-by: wanghao <wanghao@example.com>
* further optimize multi-LoRA inference; LoRA-enabled performance achieves 80%+ of non-LoRA performance
  Signed-off-by: wanghao <wanghao@example.com>
---------
Signed-off-by: wanghao <wanghao@example.com>
Co-authored-by: wanghao <wanghao@example.com>
This commit is contained in:
@@ -1,86 +1,94 @@
|
||||
"""kunlun_ops for lora"""
|
||||
|
||||
|
||||
import torch
|
||||
import xspeedgate_ops
|
||||
import time
|
||||
from torch._C import dtype
|
||||
import os
|
||||
from torch._dynamo import disable
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
expert_m: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
"""
|
||||
sgmv_shrink
|
||||
"""
|
||||
|
||||
|
||||
return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling)
|
||||
|
||||
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
seq_len_tensor.to(torch.int32),
|
||||
lora_indices_tensor.to(torch.int32),
|
||||
output_tensor,
|
||||
scaling,
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
def sgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False,
|
||||
):
|
||||
"""
|
||||
sgmv_expand
|
||||
"""
|
||||
|
||||
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0)
|
||||
|
||||
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
seq_len_tensor.to(torch.int32),
|
||||
lora_indices_tensor.to(torch.int32),
|
||||
output_tensor,
|
||||
0,
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
normed_scale: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
|
||||
def sgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
normed_scale: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False,
|
||||
):
|
||||
"""
|
||||
sgmv_expand_slice
|
||||
"""
|
||||
|
||||
|
||||
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset)
|
||||
|
||||
|
||||
|
||||
|
||||
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
seq_len_tensor.to(torch.int32),
|
||||
lora_indices_tensor.to(torch.int32),
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
)
|
||||
|
||||
|
||||
def bgmv_shrink(
|
||||
@@ -92,27 +100,33 @@ def bgmv_shrink(
|
||||
moe_index: torch.Tensor,
|
||||
expert_m: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor, # [m]
|
||||
scaling: float = 1.0
|
||||
scaling: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
bgmv_shrink
|
||||
"""
|
||||
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling)
|
||||
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
|
||||
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
|
||||
)
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
""""
|
||||
bgmv_expand
|
||||
def bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
block_statistic: torch.Tensor,
|
||||
sorted_tokens_num_lod: torch.Tensor,
|
||||
moe_index: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
""" "
|
||||
bgmv_expand
|
||||
"""
|
||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0)
|
||||
# @my_wrapper
|
||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
|
||||
)
|
||||
|
||||
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
@@ -125,9 +139,11 @@ def bgmv_expand_slice(
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
bgmv_expand_slice
|
||||
bgmv_expand_slice
|
||||
"""
|
||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset)
|
||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user