# Forked from EngineX-Cambricon/enginex-mlu370-vllm.
# Repository-viewer metadata at capture time: 169 lines, 5.0 KiB, Python.
"""
|
|
Based on:
|
|
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
|
Punica: Multi-Tenant LoRA Serving.
|
|
https://arxiv.org/abs/2310.18547
|
|
"""
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from .utils import get_lora_op_configs
|
|
|
|
|
|
@triton.jit
def _bgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
    performance.

    Launch grid: axis 0 is the split index along N, axis 1 is the batch row.
    Each program:
      1. reads its LoRA index from ``lora_indices`` and exits if it is -1;
      2. loads one input row of up to BLOCK_K elements (masked to K unless
         EVEN_K);
      3. walks its split of ``cdiv(N, SPLIT_N)`` output elements in tiles of
         BLOCK_N, loading a [BLOCK_N, BLOCK_K] weight tile, reducing
         ``input * weight`` over the K axis and storing the result
         (accumulating onto the existing output when ADD_INPUTS is set).

    Naming note: ``lora_k_stride`` is applied to the N (row) index of the
    weight and ``lora_n_stride`` to the K index; the launcher passes
    ``stride(1)`` and ``stride(2)`` of the weight tensor respectively.
    """
    pid_sn = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    lora_index = tl.load(lora_indices + cur_batch)
    # An index of -1 marks a row without LoRA: leave the output untouched.
    if lora_index == -1:
        return
    offset_k = tl.arange(0, BLOCK_K)
    offset_n = tl.arange(0, BLOCK_N)
    if EVEN_K:
        # K fills the whole block, so no mask is needed.
        tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
                          offset_k * xk_stride)  # [BLOCK_K]
    else:
        tiled_a = tl.load(
            input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
            mask=offset_k < K,
            other=0,
        )  # [BLOCK_K]
    # N must be divisible by SPLIT_N: the masks below only bound indices by
    # the split length, not by N itself.
    split_n_length = tl.cdiv(N, SPLIT_N)
    if CAST_TYPE:
        # Bring the input to the weight's element type before multiplying.
        tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
    # Slide to this program's row-block of the selected LoRA weight.
    b_ptr = (lora_ptr + l0_stride * lora_index +
             pid_sn * split_n_length * lora_k_stride)
    # The split offset is scaled by cn_stride so that it is consistent with
    # the strided accesses made through c_ptr below (a no-op for a
    # contiguous output, where cn_stride == 1).
    c_ptr = (out_ptr + cur_batch * cm_stride +
             pid_sn * split_n_length * cn_stride)
    for n in range(0, split_n_length, BLOCK_N):
        current_n = n + offset_n
        # Hint to the compiler that these BLOCK_N indices are contiguous.
        current_n_c = tl.max_contiguous(current_n, BLOCK_N)
        b_ptr_mask = ((current_n[:, None] < split_n_length) &
                      (offset_k[None, :] < K))
        c_mask = current_n < split_n_length
        tiled_b = tl.load(
            b_ptr + current_n_c[:, None] * lora_k_stride +
            offset_k[None, :] * lora_n_stride,
            mask=b_ptr_mask,
            other=0.0,
        )  # [BLOCK_N,BLOCK_K]
        if ADD_INPUTS:
            tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
            accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
        else:
            accumulator = tl.sum(tiled_a * tiled_b, 1)

        tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
|
|
|
|
|
|
@torch.inference_mode()
def _bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
    """
    Launch the bgmv expand kernel, writing the result into ``output_tensor``.

    Args:
        inputs (torch.Tensor): input tensor (float16, bfloat16 or float32).
            ``inputs.size(1)`` must equal the last dimension of
            ``lora_b_weights``. Must be contiguous.
        lora_b_weights (torch.Tensor): LoRA B weights (float16 or bfloat16),
            of shape (lora_num, size, rank) or (lora_num, 1, size, rank).
            Must be contiguous.
        output_tensor (torch.Tensor): output tensor, modified in place.
            Must be contiguous.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied. Its length determines the number of rows processed.
        add_inputs (bool, optional): Defaults to True. When True, the lora
            results are added to the existing contents of the output;
            otherwise they overwrite it.
    """
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(1) == lora_b_weights.size(-1)

    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)
    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    # tl.arange needs a power-of-two extent, so round the rank up and let the
    # kernel mask the padding whenever K itself is not a power of two.
    BLOCK_K = triton.next_power_of_2(K)
    EVEN_K = K % BLOCK_K == 0
    # A float32 input is cast to the (half-precision) weight dtype in-kernel.
    CAST_TYPE = (inputs.dtype == torch.float32
                 and lora_b_weights.dtype in [torch.float16, torch.bfloat16])
    batches = lora_indices_tensor.size(0)
    # The config supplies the remaining kernel meta-parameters (the kernel
    # requires BLOCK_N and SPLIT_N, neither of which is passed explicitly).
    config = get_lora_op_configs("expand", batches, N)
    # One program per (split along N, batch row).
    grid = lambda META: (
        META["SPLIT_N"],
        batches,
    )
    _bgmv_expand_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_K=BLOCK_K,
        EVEN_K=EVEN_K,
        ADD_INPUTS=add_inputs,
        CAST_TYPE=CAST_TYPE,
        **config,
    )
|
|
|
|
|
|
# Public entry point: expose the launcher as the custom op
# "lora::bgmv_expand", declaring that it mutates ``output_tensor``.
try:
    bgmv_expand = torch.library.custom_op(
        "lora::bgmv_expand",
        _bgmv_expand,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    # Raised e.g. when ``torch.library.custom_op`` is unavailable in the
    # installed torch; fall back to calling the plain Python function.
    bgmv_expand = _bgmv_expand
|