forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
vllm-v0.6.2/vllm/lora/ops/sgmv_expand.py | 201 lines | new file
@@ -0,0 +1,201 @@
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

import torch
import triton
import triton.language as tl


@triton.jit
def _sgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size * max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
"""
|
||||
The sgmv's expand triton kernel is based on GroupGEMM.
|
||||
"""
|
||||
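    # Grid layout: axis 0 walks the BLOCK_M x BLOCK_N output tiles of one
    # sequence, axis 1 walks batch entries; each program computes one tile of
    # out[m, n] (+)= sum_k a[m, k] * B[n, k] for its sequence's LoRA B matrix.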
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)
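    # tl.multiple_of and tl.max_contiguous are compiler hints promising that
    # these offsets are block-aligned and contiguous, which lets Triton emit
    # vectorized loads; the % M / % N keeps the hinted indices in range.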
    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Note: no trailing comma here; a stray ", )" would make a_ptr a 1-tuple
    # rather than a pointer tensor and break tl.load below.
    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
             offset_k[None, :] * xk_stride)
    b_ptr = (lora_ptr + l0_stride * lora_index +
             offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
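    # March over K (the LoRA rank) one BLOCK_K tile at a time, accumulating in
    # fp32; when K is not a multiple of BLOCK_K, masked loads zero-fill the
    # tail tile instead of reading out of bounds.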
    for k in range(tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :] < K - k * BLOCK_K,
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None] < K - k * BLOCK_K,
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    M = tl.load(seq_lens + cur_batch)
    c_mask = (offset_cm[:, None] <
              (cur_seq_start + M)) & (offset_cn[None, :] < N)
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)


@torch.inference_mode()
def _sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora'a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into sequence. E.g., if the sequence length is [4, 6], it is
|
||||
[0, 4, 10].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
|
||||
length of the sequences in the batch.
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence lengths of the sequences in the
|
||||
batch.
|
||||
token_nums (int): The token numbers in the batch. Used to verify if the
|
||||
token numbers in the inputs matches the one in the metadata.
|
||||
add_inputs (bool, optional): Defaults to False, adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
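    # One way to build the start offsets (an illustration, not upstream API):
    #   b_seq_start_loc = torch.zeros_like(seq_len_tensor)
    #   b_seq_start_loc[1:] = torch.cumsum(seq_len_tensor, dim=0)[:-1]
    # i.e. an exclusive cumulative sum of the per-sequence lengths.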
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape: (lora_num, 1, size, rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape: (lora_num, size, rank)

    assert lora_b_weights.is_contiguous()

    # TODO: tune this config

    N, K = lora_b_weights.shape[-2:]  # K = rank, N = hidden_size
    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    # fp32 inputs are cast down to the (half-precision) weight dtype inside
    # the kernel so that tl.dot sees matching operand types.
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
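    # One program per BLOCK_M x BLOCK_N output tile per batch entry; axis 0 is
    # sized for the longest sequence, and tiles past a shorter sequence's end
    # exit early via the `pid_m * BLOCK_M > M` check in the kernel.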
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    _sgmv_expand_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
    return


try:
    # torch.library.custom_op is only available in newer PyTorch releases;
    # fall back to the plain function on older versions.
    sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
                                          _sgmv_expand,
                                          mutates_args=["output_tensor"])
except AttributeError:
    sgmv_expand = _sgmv_expand
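For reviewers, a minimal smoke test of the new op. The shapes, dtypes, and the rank/hidden_size values below are illustrative assumptions, not part of this commit; only `sgmv_expand` and its signature come from the file above.

import torch
from vllm.lora.ops.sgmv_expand import sgmv_expand

# Two sequences (4 and 6 tokens), two stacked LoRA B matrices of rank 8
# expanding to hidden_size 128; metadata built as the docstring describes.
hidden_size, rank = 128, 8
seq_lens = torch.tensor([4, 6], dtype=torch.long, device="cuda")
token_nums = int(seq_lens.sum())                      # 10
b_seq_start = torch.zeros_like(seq_lens)
b_seq_start[1:] = torch.cumsum(seq_lens, dim=0)[:-1]  # [0, 4]
lora_indices = torch.tensor([0, 1], dtype=torch.long, device="cuda")

inputs = torch.randn(token_nums, rank, dtype=torch.float16, device="cuda")
lora_b = torch.randn(2, hidden_size, rank, dtype=torch.float16, device="cuda")
out = torch.zeros(token_nums, hidden_size, dtype=torch.float16, device="cuda")

sgmv_expand(inputs, lora_b, out, b_seq_start, seq_lens, lora_indices,
            batches=2, max_seq_length=6, token_nums=token_nums)

# Each sequence s should now hold inputs[s] @ lora_b[lora_indices[s]].T
ref = torch.cat([inputs[0:4] @ lora_b[0].T, inputs[4:10] @ lora_b[1].T])
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)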