Further optimize multi-lora inference, LoRA-enabled performance achieves 80%+ of non-LoRA performance (#190)
* optimize lora inference

  Signed-off-by: wanghao <wanghao@example.com>

* further optimize multi-lora inference, LoRA-enabled performance achieves 80%+ of non-LoRA performance

  Signed-off-by: wanghao <wanghao@example.com>

---------

Signed-off-by: wanghao <wanghao@example.com>
Co-authored-by: wanghao <wanghao@example.com>
This commit is contained in:
@@ -1,11 +1,6 @@
|
|||||||
"""kunlun_ops for lora"""
|
"""kunlun_ops for lora"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import xspeedgate_ops
|
|
||||||
import time
|
|
||||||
from torch._C import dtype
|
|
||||||
import os
|
|
||||||
from torch._dynamo import disable
|
|
||||||
|
|
||||||
|
|
||||||
def sgmv_shrink(
|
def sgmv_shrink(
|
||||||
@@ -27,60 +22,73 @@ def sgmv_shrink(
|
|||||||
"""
|
"""
|
||||||
sgmv_shrink
|
sgmv_shrink
|
||||||
"""
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
|
||||||
|
inputs,
|
||||||
|
lora_a_weights,
|
||||||
|
seq_len_tensor.to(torch.int32),
|
||||||
|
lora_indices_tensor.to(torch.int32),
|
||||||
|
output_tensor,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling)
|
def sgmv_expand(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
def sgmv_expand(inputs: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
lora_b_weights: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
lora_indices_tensor: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
batches: int,
|
||||||
seq_len_tensor: torch.Tensor,
|
max_seq_length: int,
|
||||||
lora_indices_tensor: torch.Tensor,
|
token_nums: int,
|
||||||
batches: int,
|
add_inputs: bool = False,
|
||||||
max_seq_length: int,
|
):
|
||||||
token_nums: int,
|
|
||||||
add_inputs: bool = False):
|
|
||||||
"""
|
"""
|
||||||
sgmv_expand
|
sgmv_expand
|
||||||
"""
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||||
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0)
|
inputs,
|
||||||
|
lora_b_weights,
|
||||||
|
seq_len_tensor.to(torch.int32),
|
||||||
|
lora_indices_tensor.to(torch.int32),
|
||||||
|
output_tensor,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sgmv_expand_slice(
|
||||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
lora_b_weights: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
normed_scale: torch.Tensor,
|
normed_scale: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
seq_len_tensor: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
lora_indices_tensor: torch.Tensor,
|
||||||
batches: int,
|
batches: int,
|
||||||
max_seq_length: int,
|
max_seq_length: int,
|
||||||
token_nums: int,
|
token_nums: int,
|
||||||
slice_offset: int,
|
slice_offset: int,
|
||||||
slice_size: int,
|
slice_size: int,
|
||||||
add_inputs: bool = False):
|
add_inputs: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
sgmv_expand_slice
|
sgmv_expand_slice
|
||||||
"""
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||||
|
inputs,
|
||||||
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset)
|
lora_b_weights,
|
||||||
|
seq_len_tensor.to(torch.int32),
|
||||||
|
lora_indices_tensor.to(torch.int32),
|
||||||
|
output_tensor,
|
||||||
|
slice_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_shrink(
|
def bgmv_shrink(
|
||||||
@@ -92,27 +100,33 @@ def bgmv_shrink(
|
|||||||
moe_index: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
expert_m: torch.Tensor,
|
expert_m: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor, # [m]
|
lora_indices_tensor: torch.Tensor, # [m]
|
||||||
scaling: float = 1.0
|
scaling: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
bgmv_shrink
|
bgmv_shrink
|
||||||
"""
|
"""
|
||||||
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling)
|
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
|
||||||
|
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand(inputs: torch.Tensor,
|
def bgmv_expand(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
add_inputs: bool = True):
|
lora_indices_tensor: torch.Tensor,
|
||||||
""""
|
add_inputs: bool = True,
|
||||||
bgmv_expand
|
):
|
||||||
|
""" "
|
||||||
|
bgmv_expand
|
||||||
"""
|
"""
|
||||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0)
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
# @my_wrapper
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand_slice(
|
def bgmv_expand_slice(
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
@@ -125,9 +139,11 @@ def bgmv_expand_slice(
|
|||||||
lora_indices_tensor: torch.Tensor,
|
lora_indices_tensor: torch.Tensor,
|
||||||
slice_offset: int,
|
slice_offset: int,
|
||||||
slice_size: int,
|
slice_size: int,
|
||||||
add_inputs: bool = True
|
add_inputs: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
bgmv_expand_slice
|
bgmv_expand_slice
|
||||||
"""
|
"""
|
||||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset)
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
|
||||||
|
)
|
||||||
|
|||||||
@@ -22,16 +22,11 @@ Punica: Multi-Tenant LoRA Serving.
|
|||||||
https://arxiv.org/abs/2310.18547
|
https://arxiv.org/abs/2310.18547
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, Union, final
|
|
||||||
|
|
||||||
import torch
|
|
||||||
# Disable torchdynamo for all functions in this file
|
|
||||||
torch._dynamo.config.disable = True
|
|
||||||
|
|
||||||
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||||
|
|
||||||
from vllm_kunlun.lora.ops.kunlun_ops import (
|
from vllm_kunlun.lora.ops.kunlun_ops import (
|
||||||
bgmv_expand,
|
bgmv_expand,
|
||||||
@@ -42,7 +37,7 @@ from vllm_kunlun.lora.ops.kunlun_ops import (
|
|||||||
sgmv_shrink,
|
sgmv_shrink,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
# Disable torchdynamo for all functions in this file
|
||||||
|
|
||||||
|
|
||||||
# The platforms that are compatible with the PyTorch-native implementation can
|
# The platforms that are compatible with the PyTorch-native implementation can
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
"""vllm_utils_wrapper.py"""
|
"""vllm_utils_wrapper.py"""
|
||||||
|
|
||||||
import vllm.distributed.parallel_state as parallel_state
|
|
||||||
import vllm.utils as _orig
|
|
||||||
from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple
|
|
||||||
from types import SimpleNamespace
|
|
||||||
import torch
|
|
||||||
from torch.library import Library
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import socket
|
||||||
import typing
|
import typing
|
||||||
from torch.library import register_fake
|
from types import SimpleNamespace
|
||||||
import vllm_kunlun._kunlun
|
from typing import Any, Callable, List, Optional, Tuple, Union, get_args, get_origin
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm.distributed.parallel_state as parallel_state
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
import vllm.utils as _orig
|
||||||
|
from torch.library import Library, register_fake
|
||||||
|
|
||||||
|
|
||||||
def patch_annotations_for_schema(func):
|
def patch_annotations_for_schema(func):
|
||||||
@@ -87,7 +87,7 @@ def direct_register_custom_op(
|
|||||||
import torch.library
|
import torch.library
|
||||||
|
|
||||||
if hasattr(torch.library, "infer_schema"):
|
if hasattr(torch.library, "infer_schema"):
|
||||||
patched_func = patch_annotations_for_schema(op_func)
|
patch_annotations_for_schema(op_func)
|
||||||
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
||||||
else:
|
else:
|
||||||
# for pytorch 2.4
|
# for pytorch 2.4
|
||||||
@@ -153,7 +153,7 @@ _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
|
|||||||
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
|
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
|
||||||
_wrapped._get_open_port = _get_open_port
|
_wrapped._get_open_port = _get_open_port
|
||||||
|
|
||||||
import sys
|
import sys # noqa: E402
|
||||||
|
|
||||||
sys.modules["vllm.utils"] = _wrapped
|
sys.modules["vllm.utils"] = _wrapped
|
||||||
|
|
||||||
@@ -204,11 +204,10 @@ parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce
|
|||||||
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
|
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
|
||||||
|
|
||||||
|
|
||||||
from torch.library import custom_op, impl
|
from typing import Optional # noqa: E402
|
||||||
import torch
|
|
||||||
from vllm import _custom_ops as ops
|
import torch # noqa: E402
|
||||||
from typing import Optional, List
|
from torch.library import custom_op, impl # noqa: E402
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
@custom_op("_C::rms_norm", mutates_args=())
|
@custom_op("_C::rms_norm", mutates_args=())
|
||||||
@@ -379,9 +378,9 @@ def silu_and_mul_quant_xpu(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch # noqa: E402
|
||||||
import xtorch_ops
|
import xtorch_ops # noqa: E402
|
||||||
from torch.library import custom_op, impl
|
from torch.library import custom_op, impl # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@custom_op("_C::add_rmsnorm", mutates_args=())
|
@custom_op("_C::add_rmsnorm", mutates_args=())
|
||||||
@@ -472,7 +471,7 @@ def rmsnorm_cuda(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def _fake_rmsnorm(
|
def _fake_rmsnorm(
|
||||||
@@ -618,7 +617,6 @@ split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox)
|
|||||||
|
|
||||||
# register fake op impl here
|
# register fake op impl here
|
||||||
# for torch.dynamo
|
# for torch.dynamo
|
||||||
from torch.library import register_fake
|
|
||||||
|
|
||||||
if hasattr(torch.ops.custom_ops, "fc_fusion"):
|
if hasattr(torch.ops.custom_ops, "fc_fusion"):
|
||||||
|
|
||||||
@@ -1396,7 +1394,7 @@ def awq_dequantize_cuda(
|
|||||||
device=qweight.device,
|
device=qweight.device,
|
||||||
)
|
)
|
||||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||||
out = xtorch_ops.awq_dequantize(
|
xtorch_ops.awq_dequantize(
|
||||||
qweight=qweight,
|
qweight=qweight,
|
||||||
scales=scales,
|
scales=scales,
|
||||||
zeros=zeros,
|
zeros=zeros,
|
||||||
@@ -1915,7 +1913,7 @@ def apply_repetition_penalties_(
|
|||||||
|
|
||||||
|
|
||||||
@impl("_C::apply_repetition_penalties_", "CUDA")
|
@impl("_C::apply_repetition_penalties_", "CUDA")
|
||||||
def apply_repetition_penalties_(
|
def apply_repetition_penalties_cuda(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
prompt_mask: torch.Tensor,
|
prompt_mask: torch.Tensor,
|
||||||
output_mask: torch.Tensor,
|
output_mask: torch.Tensor,
|
||||||
@@ -2341,34 +2339,499 @@ dequant_int4.register_fake(_fake_dequant_int4)
|
|||||||
##################################################
|
##################################################
|
||||||
@custom_op("_C::fast_topkv2", mutates_args=())
|
@custom_op("_C::fast_topkv2", mutates_args=())
|
||||||
def fast_topkv2(
|
def fast_topkv2(
|
||||||
score: torch.Tensor,
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
lengths: torch.Tensor,
|
) -> torch.Tensor:
|
||||||
topk: Optional[int] = 2048) -> torch.Tensor:
|
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
topk_indices = xtorch_ops.fast_topkv2(
|
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||||
score=score,
|
|
||||||
lengths=lengths,
|
|
||||||
topk=topk)
|
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
@impl("_C::fast_topkv2", "CUDA")
|
@impl("_C::fast_topkv2", "CUDA")
|
||||||
def fast_topkv2_cuda(
|
def fast_topkv2_cuda(
|
||||||
score: torch.Tensor,
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
lengths: torch.Tensor,
|
) -> torch.Tensor:
|
||||||
topk: Optional[int] = 2048) -> torch.Tensor:
|
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
topk_indices = xtorch_ops.fast_topkv2(
|
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||||
score=score,
|
|
||||||
lengths=lengths,
|
|
||||||
topk=topk)
|
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
def _fake_fast_topkv2(
|
def _fake_fast_topkv2(
|
||||||
score: torch.Tensor,
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
lengths: torch.Tensor,
|
) -> torch.Tensor:
|
||||||
topk: Optional[int] = 2048) -> torch.Tensor:
|
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
fast_topkv2.register_fake(_fake_fast_topkv2)
|
fast_topkv2.register_fake(_fake_fast_topkv2)
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------------- LoRA ops --------------------
|
||||||
|
##################################################
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- sgmv_shrink_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sgmv_shrink_lora", mutates_args=())
|
||||||
|
def sgmv_shrink_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
scaling: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(
|
||||||
|
# inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling
|
||||||
|
# )
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
|
||||||
|
inputs,
|
||||||
|
lora_a_weights,
|
||||||
|
seq_len_tensor,
|
||||||
|
lora_indices_tensor,
|
||||||
|
output_tensor,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::sgmv_shrink_lora", "CUDA")
|
||||||
|
def sgmv_shrink_lora_cuda(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
scaling: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(
|
||||||
|
# inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling
|
||||||
|
# )
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
|
||||||
|
inputs,
|
||||||
|
lora_a_weights,
|
||||||
|
seq_len_tensor,
|
||||||
|
lora_indices_tensor,
|
||||||
|
output_tensor,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_sgmv_shrink_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
scaling: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
sgmv_shrink_lora.register_fake(_fake_sgmv_shrink_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- sgmv_expand_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sgmv_expand_lora", mutates_args=())
|
||||||
|
def sgmv_expand_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
|
||||||
|
# inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
|
||||||
|
# )
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||||
|
inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::sgmv_expand_lora", "CUDA")
|
||||||
|
def sgmv_expand_lora_cuda(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
|
||||||
|
# inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
|
||||||
|
# )
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||||
|
inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_sgmv_expand_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
sgmv_expand_lora.register_fake(_fake_sgmv_expand_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- sgmv_expand_slice_lora -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sgmv_expand_slice_lora", mutates_args=())
|
||||||
|
def sgmv_expand_slice_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
|
||||||
|
inputs,
|
||||||
|
lora_b_weights,
|
||||||
|
seq_len_tensor,
|
||||||
|
lora_indices_tensor,
|
||||||
|
output_tensor,
|
||||||
|
slice_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::sgmv_expand_slice_lora", "CUDA")
|
||||||
|
def sgmv_expand_slice_lora_cuda(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
|
||||||
|
inputs,
|
||||||
|
lora_b_weights,
|
||||||
|
seq_len_tensor,
|
||||||
|
lora_indices_tensor,
|
||||||
|
output_tensor,
|
||||||
|
slice_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_sgmv_expand_slice_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
sgmv_expand_slice_lora.register_fake(_fake_sgmv_expand_slice_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- bgmv_shrink_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::bgmv_shrink_lora", mutates_args=())
|
||||||
|
def bgmv_shrink_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
scaling: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
|
||||||
|
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::bgmv_shrink_lora", "CUDA")
|
||||||
|
def bgmv_shrink_lora_cuda(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
scaling: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
|
||||||
|
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_bgmv_shrink_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
scaling: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
bgmv_shrink_lora.register_fake(_fake_bgmv_shrink_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- bgmv_expand_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::bgmv_expand_lora", mutates_args=())
|
||||||
|
def bgmv_expand_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::bgmv_expand_lora", "CUDA")
|
||||||
|
def bgmv_expand_lora_cuda(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_bgmv_expand_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
bgmv_expand_lora.register_fake(_fake_bgmv_expand_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- bgmv_expand_slice_lora -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::bgmv_expand_slice_lora", mutates_args=())
|
||||||
|
def bgmv_expand_slice_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::bgmv_expand_slice_lora", "CUDA")
|
||||||
|
def bgmv_expand_slice_lora_cuda(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_bgmv_expand_slice_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
bgmv_expand_slice_lora.register_fake(_fake_bgmv_expand_slice_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- lora_matmul_inplace ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::lora_matmul_inplace", mutates_args=())
|
||||||
|
def lora_matmul_inplace(
|
||||||
|
x: torch.Tensor,
|
||||||
|
w: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
x_trans: bool = False,
|
||||||
|
w_trans: bool = True,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
beta: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.matmul(
|
||||||
|
x=x.contiguous(),
|
||||||
|
w=w.contiguous(),
|
||||||
|
out=output_tensor,
|
||||||
|
x_trans=x_trans,
|
||||||
|
w_trans=w_trans,
|
||||||
|
alpha=alpha,
|
||||||
|
beta=beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::lora_matmul_inplace", "CUDA")
|
||||||
|
def lora_matmul_inplace_cuda(
|
||||||
|
x: torch.Tensor,
|
||||||
|
w: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
x_trans: bool = False,
|
||||||
|
w_trans: bool = True,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
beta: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.matmul(
|
||||||
|
x=x.contiguous(),
|
||||||
|
w=w.contiguous(),
|
||||||
|
out=output_tensor,
|
||||||
|
x_trans=x_trans,
|
||||||
|
w_trans=w_trans,
|
||||||
|
alpha=alpha,
|
||||||
|
beta=beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_lora_matmul_inplace(
|
||||||
|
x: torch.Tensor,
|
||||||
|
w: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
x_trans: bool = False,
|
||||||
|
w_trans: bool = True,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
beta: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
lora_matmul_inplace.register_fake(_fake_lora_matmul_inplace)
|
||||||
|
|||||||
Reference in New Issue
Block a user