forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
150
vllm-v0.6.2/vllm/lora/ops/bgmv_shrink.py
Normal file
150
vllm-v0.6.2/vllm/lora/ops/bgmv_shrink.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_shrink_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
scaling,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
|
||||
performance
|
||||
"""
|
||||
pid_sk = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
|
||||
a_ptr = input_ptr + cur_batch * xm_stride
|
||||
b_ptr = lora_ptr + l0_stride * lora_index
|
||||
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_K * SPLIT_K):
|
||||
current_k = k + offset_k
|
||||
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
|
||||
tiled_a = tl.load(
|
||||
a_ptr + current_k_c,
|
||||
mask=current_k < K,
|
||||
other=0.0,
|
||||
) # [BLOCK_K]
|
||||
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
|
||||
|
||||
tiled_b = tl.load(
|
||||
b_ptr + offset_n[:, None] * lora_k_stride +
|
||||
current_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
|
||||
accumulator += tl.sum(tiled_a * tiled_b, 1)
|
||||
accumulator *= scaling
|
||||
offset_cn = tl.arange(0, BLOCK_N)
|
||||
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
|
||||
c_mask = offset_cn < N
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (torch.Tensor): lora'a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
assert lora_a_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert inputs.is_contiguous()
|
||||
|
||||
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
|
||||
assert lora_a_weights.size(1) == 1
|
||||
lora_a_weights = lora_a_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
|
||||
assert lora_a_weights.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
# TODO tuning this config
|
||||
batches = lora_indices_tensor.size(0)
|
||||
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
|
||||
BLOCK_N = triton.next_power_of_2(N)
|
||||
# First try to load optimal config from the file
|
||||
config = get_lora_op_configs("bgmv_shrink", batches, K)
|
||||
|
||||
grid = lambda META: (
|
||||
META["SPLIT_K"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_a_weights.stride(0),
|
||||
lora_a_weights.stride(1),
|
||||
lora_a_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_N=BLOCK_N,
|
||||
**config,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
|
||||
_bgmv_shrink,
|
||||
mutates_args=["output_tensor"])
|
||||
except AttributeError:
|
||||
bgmv_shrink = _bgmv_shrink
|
||||
Reference in New Issue
Block a user