forked from EngineX-MetaX/enginex-c_series-vllm
[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
0
vllm/lora/ops/__init__.py
Normal file
0
vllm/lora/ops/__init__.py
Normal file
16
vllm/lora/ops/torch_ops/__init__.py
Normal file
16
vllm/lora/ops/torch_ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
|
||||
from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
|
||||
sgmv_expand, sgmv_expand_slice,
|
||||
sgmv_shrink)
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_expand_slice",
|
||||
"sgmv_shrink",
|
||||
]
|
||||
119
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
119
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
# LoRA adapter and model may add different amounts of padding to output
|
||||
common_len = min(outputs.shape[1], output_tensor.shape[1])
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, :common_len] += outputs[:limit, :common_len]
|
||||
else:
|
||||
output_tensor[:, :common_len] = outputs[:limit, :common_len]
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
|
||||
scaling)
|
||||
|
||||
|
||||
def bgmv_shrink(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
slice_offset, slice_size, add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(
|
||||
dtype=output_tensor.dtype)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:]
|
||||
else:
|
||||
output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:]
|
||||
12
vllm/lora/ops/triton_ops/__init__.py
Normal file
12
vllm/lora/ops/triton_ops/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
|
||||
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
|
||||
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
|
||||
|
||||
__all__ = [
|
||||
"lora_expand",
|
||||
"lora_shrink",
|
||||
"LoRAKernelMeta",
|
||||
]
|
||||
243
vllm/lora/ops/triton_ops/kernel_utils.py
Normal file
243
vllm/lora/ops/triton_ops/kernel_utils.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Utilities for Punica kernel construction.
|
||||
"""
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
|
||||
b_dtype: tl.constexpr):
|
||||
"""
|
||||
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
|
||||
B (k x n), iterate, through the K dimension to compute the partial/complete
|
||||
matrix block product.
|
||||
If SPLIT_K == 1, the output m x n product is complete.
|
||||
If SPLIT_K > 1, the thread block computes partial outputs. The partial
|
||||
outputs are then atomically summed in the caller code.
|
||||
Args:
|
||||
a_ptr: Array of pointers, identifying rows of A
|
||||
b_ptr: Array of pointers, identifying columns of B
|
||||
ak_stride: K dimension stride of the A matrix
|
||||
bk_stride: K dimension stride of the B matrix
|
||||
K: Length of the K dimension
|
||||
BLOCK_M: M dimension of the output block m x n
|
||||
BLOCK_N: N dimension of the output block m x n
|
||||
BLOCK_K: K dimension atom
|
||||
EVEN_K: True if the blocks of A and B can be loaded without any
|
||||
masking.
|
||||
SPLIT_K: Parameter signifying parallelism in the K dimension.
|
||||
CAST_TYPE: if True, cast the values from the A matrix to the B
|
||||
matrix dtype.
|
||||
b_dtype: datatype of the B matrix
|
||||
"""
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr)
|
||||
tiled_b = tl.load(b_ptr)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=offset_k[None, :]
|
||||
< K - k * (BLOCK_K * SPLIT_K),
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=offset_k[:, None]
|
||||
< K - k * (BLOCK_K * SPLIT_K),
|
||||
other=0)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(b_dtype)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * SPLIT_K * ak_stride
|
||||
b_ptr += BLOCK_K * SPLIT_K * bk_stride
|
||||
return accumulator
|
||||
|
||||
|
||||
@triton.jit
|
||||
def do_expand_kernel(
|
||||
pid_n,
|
||||
lora_index,
|
||||
slice_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
M_LEN,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
slice_start_loc,
|
||||
# input ptr strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride,
|
||||
# lora ptr strides
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr,
|
||||
# out ptr strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
# constants
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SAME_STRIDE: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Given an array of integers that identifies the rows of A, ram,
|
||||
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
|
||||
a slice_id that identifies the input/output slice,
|
||||
compute the matrix product and store in the appropriate output location.
|
||||
Given that this is an expand kernel, we don't perform any split-K reduction
|
||||
as the K dimension is assumed to be small.
|
||||
"""
|
||||
|
||||
# ls_d*_ptr can be either an integer or a pointer
|
||||
if SAME_STRIDE:
|
||||
# integer
|
||||
cur_lora_d0_stride = ls_d0_ptr
|
||||
cur_lora_d1_stride = ls_d1_ptr
|
||||
cur_lora_d2_stride = ls_d2_ptr
|
||||
else:
|
||||
# pointer
|
||||
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
|
||||
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
|
||||
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
|
||||
|
||||
# Identify the input_ptr and lora_ptr from slice_id.
|
||||
if SLICE_NUM == 1:
|
||||
cur_input_ptr = input_ptr
|
||||
cur_lora_ptr = lora_ptr
|
||||
else:
|
||||
cur_input_ptr = input_ptr + slice_id * input_d0_stride
|
||||
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
||||
tl.pointer_type(out_ptr.dtype.element_ty))
|
||||
|
||||
# Identify the column indices of B to process.
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
# Identify A and B block pointers
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
|
||||
offset_k[None, :] * input_d2_stride)
|
||||
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
|
||||
offset_k[:, None] * cur_lora_d2_stride +
|
||||
rbn[None, :] * cur_lora_d1_stride)
|
||||
|
||||
# Compute the block matrix product.
|
||||
SPLIT_K = 1
|
||||
accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride,
|
||||
offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K,
|
||||
CAST_TYPE, cur_lora_ptr.dtype.element_ty)
|
||||
|
||||
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
|
||||
if SLICE_NUM == 1:
|
||||
cur_slice_start = slice_start_loc
|
||||
else:
|
||||
cur_slice_start = tl.load(slice_start_loc + slice_id)
|
||||
|
||||
# Identify the C output pointers to store the results of the accumulator.
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
|
||||
offset_cm = tl.arange(0, BLOCK_M)
|
||||
c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
|
||||
offset_cn[None, :] * output_d1_stride)
|
||||
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
|
||||
< (cur_slice_start + N))
|
||||
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def do_shrink_kernel(
|
||||
pid_n,
|
||||
pid_sk,
|
||||
slice_id,
|
||||
lora_index,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
M_LEN,
|
||||
ram,
|
||||
# input strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
# lora strides
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride,
|
||||
# output strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride,
|
||||
scaling,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Given an array of integers that identifies the rows of A, ram,
|
||||
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
|
||||
a slice_id that identifies the input/output slice, compute the
|
||||
matrix product and store in the appropriate output location.
|
||||
"""
|
||||
|
||||
# Identify the lora_ptr from slice_id.
|
||||
if SLICE_NUM == 1:
|
||||
# current lora ptr
|
||||
cur_lora_ptr = lora_ptr
|
||||
else:
|
||||
# current lora ptr
|
||||
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
|
||||
tl.pointer_type(input_ptr.dtype.element_ty))
|
||||
|
||||
# Identify the column indices of B to process.
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
# Identify A and B block pointers
|
||||
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
|
||||
offset_k[None, :] * input_d1_stride)
|
||||
b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
|
||||
rbn[None, :] * lora_d1_stride +
|
||||
offset_k[:, None] * lora_d2_stride)
|
||||
|
||||
# Compute partial/complete block matrix product.
|
||||
accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k,
|
||||
K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False,
|
||||
cur_lora_ptr.dtype.element_ty)
|
||||
|
||||
# Identify the C output pointers to store the results of the accumulator.
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_cm = tl.arange(0, BLOCK_M)
|
||||
cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
|
||||
slice_id * output_d0_stride)
|
||||
c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
|
||||
None, :] * output_d2_stride
|
||||
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
|
||||
|
||||
accumulator *= scaling
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
290
vllm/lora/ops/triton_ops/lora_expand_op.py
Normal file
290
vllm/lora/ops/triton_ops/lora_expand_op.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _lora_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
slice_start_loc,
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride, # 1
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr, # 1
|
||||
output_d0_stride,
|
||||
output_d1_stride, # 1
|
||||
output_hs_ptr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
SLICE_NUM: tl.constexpr,
|
||||
SAME_STRIDE: tl.constexpr):
|
||||
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
|
||||
pid_mn = tl.program_id(axis=0)
|
||||
pid_m = pid_mn % cta_m_num
|
||||
pid_n = (pid_mn // cta_m_num) % cta_n_num
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
|
||||
lora_id = tl.load(lora_ids + lora_idx)
|
||||
if lora_id == -1:
|
||||
# Early exit for the no-lora case.
|
||||
return
|
||||
|
||||
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
|
||||
|
||||
cta_m_offset = pid_m * BLOCK_M
|
||||
if cta_m_offset >= lora_m_size:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# When the output dimensions of each slice are the same,cur_n=N, otherwise
|
||||
# cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
|
||||
# qkv linear.
|
||||
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
|
||||
if pid_n * BLOCK_N >= curr_N:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# num rows this CTA should process.
|
||||
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
|
||||
|
||||
# Identify all rows that this CTA should process.
|
||||
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
|
||||
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
|
||||
lora_m_indices_start + cta_m_offset)
|
||||
|
||||
# Load all relevant row indices.
|
||||
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
|
||||
ram = tl.load(cta_lora_seq_indices + offset_m)
|
||||
|
||||
do_expand_kernel(
|
||||
pid_n,
|
||||
lora_id,
|
||||
slice_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
curr_N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
slice_start_loc,
|
||||
# input ptr strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
input_d2_stride,
|
||||
# lora ptr strides
|
||||
ls_d0_ptr,
|
||||
ls_d1_ptr,
|
||||
ls_d2_ptr,
|
||||
# out ptr strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
# constants
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
SAME_STRIDE,
|
||||
SLICE_NUM,
|
||||
EVEN_K,
|
||||
CAST_TYPE,
|
||||
ADD_INPUTS)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _lora_expand(
|
||||
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
lora_b_weights: list[
|
||||
torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
|
||||
output_tensor: torch.
|
||||
Tensor, # shape [num_tokens, hidden_size * num_slices]
|
||||
token_lora_mapping: torch.Tensor, # shape [num_tokens]
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
|
||||
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (list[torch.Tensor]): lora'b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
token_lora_mapping (torch.Tensor): A tensor mapping each input token
|
||||
to the lora-id related to that token. A value of -1 indicates that
|
||||
LoRA doesn't apply to that token.
|
||||
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
|
||||
the A matrix grouped by LoRA IDs.
|
||||
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
|
||||
of tokens that are to be processed by LoRA ID lora_ids[i]
|
||||
lora_token_start_loc (torch.Tensor): A cumulative sum of
|
||||
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
|
||||
lora_token_start_loc[i], along with num_tokens_per_lora[i]
|
||||
identifies the region in token_indices_sorted_by_lora_ids that
|
||||
LoRA lora_ids[i] should process.
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
|
||||
if there are any requests that require LoRA.
|
||||
offset_start (int, optional): Offset start for output_tensor.
|
||||
Defaults to 0.
|
||||
add_inputs (bool, optional): Whether to add the input tensor to the
|
||||
output tensor. Defaults to False.
|
||||
"""
|
||||
|
||||
assert no_lora_flag_cpu.numel() == 1
|
||||
if no_lora_flag_cpu.item():
|
||||
# None of the inputs require LoRA.
|
||||
return
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
for weight in lora_b_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(0) == len(lora_b_weights)
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
# metadata sanity check.
|
||||
M = inputs.size(1)
|
||||
assert token_lora_mapping.size(0) == M
|
||||
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
|
||||
0)
|
||||
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
|
||||
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
|
||||
|
||||
(slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
|
||||
same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
|
||||
inputs.device)
|
||||
|
||||
K = lora_b_weights[0].shape[-1] # K= rank
|
||||
ADD_INPUTS = add_inputs
|
||||
MAX_LORAS = lora_ids.size(0)
|
||||
CAST_TYPE = False
|
||||
NUM_SLICES = len(lora_b_weights)
|
||||
|
||||
# Triton kernel configs.
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 16
|
||||
NUM_WARPS = 4
|
||||
NUM_CTAS = 1
|
||||
NUM_STAGES = 2
|
||||
|
||||
EVEN_K = K % BLOCK_K == 0 # type: ignore
|
||||
|
||||
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
# cost of wasteful thread block launch when only a few input tokens require
|
||||
# LoRA. This might not be the best in all cases.
|
||||
grid = (
|
||||
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks simply exit.
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_lora_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
M,
|
||||
MAX_N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
slice_start_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
inputs.stride(2),
|
||||
lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
hidden_sizes_tensor,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
NUM_SLICES,
|
||||
same_stride,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
num_stages=NUM_STAGES,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _lora_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="lora_expand",
|
||||
op_func=_lora_expand,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_lora_expand_fake,
|
||||
)
|
||||
lora_expand = torch.ops.vllm.lora_expand
|
||||
|
||||
except AttributeError:
|
||||
lora_expand = _lora_expand
|
||||
148
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
Normal file
148
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
LoRA kernels metadata preparation utilities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAKernelMeta:
|
||||
token_lora_mapping: torch.Tensor
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor
|
||||
active_lora_ids: torch.Tensor
|
||||
num_tokens_per_lora: torch.Tensor
|
||||
lora_token_start_loc: torch.Tensor
|
||||
|
||||
# The V1 architecture uses the traced torch.compile graphs to execute
|
||||
# a forward pass. Things to note about this process,
|
||||
# 1. The tracing infers all python scalar datatype objects into a constant
|
||||
# value.
|
||||
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
|
||||
# is an experimental feature in pytorch)
|
||||
# 3. The internals of torch.ops functions are not traced.
|
||||
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
|
||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||
no_lora_flag_cpu: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make(max_loras: int, max_num_tokens: int,
|
||||
device: Union[torch.device, str]) -> "LoRAKernelMeta":
|
||||
|
||||
token_lora_mapping = torch.empty(max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# +1 because "no-lora" is also a possibility
|
||||
# example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1]
|
||||
# is a possibility.
|
||||
active_lora_ids = torch.empty(max_loras + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# using running example, [3, 10, 5, 2] is a possibility.
|
||||
num_tokens_per_lora = torch.zeros(max_loras + 1,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
# +2 for this because, the first index is always 0.
|
||||
# using running example, lora_token_start_loc
|
||||
# is [0, 3, 13, 18, 20].
|
||||
lora_token_start_loc = torch.zeros(max_loras + 2,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
no_lora_flag_cpu = torch.tensor([False],
|
||||
dtype=torch.bool,
|
||||
device='cpu')
|
||||
|
||||
return LoRAKernelMeta(
|
||||
token_lora_mapping=token_lora_mapping,
|
||||
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
|
||||
active_lora_ids=active_lora_ids,
|
||||
num_tokens_per_lora=num_tokens_per_lora,
|
||||
lora_token_start_loc=lora_token_start_loc,
|
||||
no_lora_flag_cpu=no_lora_flag_cpu)
|
||||
|
||||
def _reset(self):
|
||||
self.active_lora_ids.fill_(-1)
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
self.no_lora_flag_cpu.fill_(False)
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
Prepare kernel metadata tensors for the current forward pass.
|
||||
|
||||
Args:
|
||||
token_lora_tensor (torch.Tensor): Tensor containing lora indices
|
||||
for each input token.
|
||||
"""
|
||||
|
||||
self._reset()
|
||||
|
||||
# Check and record no-lora case.
|
||||
no_lora = torch.all(token_lora_mapping == -1)
|
||||
self.no_lora_flag_cpu[0] = no_lora
|
||||
|
||||
if no_lora:
|
||||
# Early exit. LoRA kernels will not be run.
|
||||
return
|
||||
|
||||
num_tokens = token_lora_mapping.size(0)
|
||||
|
||||
# copy token lora mapping
|
||||
self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping,
|
||||
non_blocking=True)
|
||||
|
||||
# token_indices_sorted_by_lora_ids
|
||||
_, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
|
||||
stable=True)
|
||||
# start gpu transfer
|
||||
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
|
||||
token_indices_sorted_by_lora_ids, non_blocking=True)
|
||||
|
||||
# active_lora_ids, num_tokens_per_lora
|
||||
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
|
||||
non_blocking=True)
|
||||
self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_(
|
||||
num_tokens_per_lora, non_blocking=True)
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_(
|
||||
lora_token_start_loc, non_blocking=True)
|
||||
|
||||
def meta_args(
|
||||
self, token_nums: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
forward pass execution of the kernel. The function returns all the
|
||||
metadata required by the kernel, in order, as a tuple, so it can be
|
||||
unpacked directly during the lora_shrink/lora_expand function call.
|
||||
|
||||
Args:
|
||||
token_nums (int): Number of input tokens in the current forward
|
||||
pass.
|
||||
"""
|
||||
return (
|
||||
self.token_lora_mapping[:token_nums],
|
||||
self.token_indices_sorted_by_lora_ids[:token_nums],
|
||||
self.num_tokens_per_lora,
|
||||
self.lora_token_start_loc,
|
||||
self.active_lora_ids,
|
||||
self.no_lora_flag_cpu,
|
||||
)
|
||||
244
vllm/lora/ops/triton_ops/lora_shrink_op.py
Normal file
244
vllm/lora/ops/triton_ops/lora_shrink_op.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
|
||||
token_indices_sorted_by_lora_ids, num_tokens_per_lora,
|
||||
lora_token_start_loc, lora_ids, scaling,
|
||||
input_d0_stride, input_d1_stride, lora_d0_stride,
|
||||
lora_d1_stride, lora_d2_stride, output_d0_stride,
|
||||
output_d1_stride, output_d2_stride,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
|
||||
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||
|
||||
pid_sk_m_n = tl.program_id(axis=0)
|
||||
pid_sk = pid_sk_m_n % SPLIT_K
|
||||
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
|
||||
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
|
||||
|
||||
slice_id = tl.program_id(axis=1)
|
||||
lora_idx = tl.program_id(axis=2)
|
||||
|
||||
lora_id = tl.load(lora_ids + lora_idx)
|
||||
if lora_id == -1:
|
||||
# Early exit for the no-lora case.
|
||||
return
|
||||
|
||||
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
|
||||
|
||||
cta_m_offset = pid_m * BLOCK_M
|
||||
if cta_m_offset >= lora_m_size:
|
||||
# Early exit CTA.
|
||||
return
|
||||
|
||||
# num rows this CTA should process.
|
||||
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
|
||||
|
||||
# Identify all rows that this CTA should process.
|
||||
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
|
||||
cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
|
||||
lora_m_indices_start + cta_m_offset)
|
||||
|
||||
# Load all relevant row indices.
|
||||
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
|
||||
ram = tl.load(cta_lora_seq_indices + offset_m)
|
||||
|
||||
do_shrink_kernel(
|
||||
pid_n,
|
||||
pid_sk,
|
||||
slice_id,
|
||||
lora_id,
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
cta_m_len,
|
||||
ram, # array identifying the rows of Input ptr to operate on
|
||||
# input strides
|
||||
input_d0_stride,
|
||||
input_d1_stride,
|
||||
# lora strides
|
||||
lora_d0_stride,
|
||||
lora_d1_stride,
|
||||
lora_d2_stride,
|
||||
# output strides
|
||||
output_d0_stride,
|
||||
output_d1_stride,
|
||||
output_d2_stride,
|
||||
scaling,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
SLICE_NUM)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _lora_shrink(
|
||||
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
|
||||
lora_a_weights: list[
|
||||
torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
|
||||
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
|
||||
token_lora_mapping: torch.Tensor, # shape [num_tokens]
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
|
||||
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
|
||||
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
|
||||
lora_ids: torch.Tensor, # shape [max-loras + 1]
|
||||
no_lora_flag_cpu: torch.Tensor, # shape [1]
|
||||
scaling: float,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor
|
||||
lora_a_weights (list[torch.Tensor]): LoRA weights
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
token_lora_mapping (torch.Tensor): A tensor mapping each input token
|
||||
to the lora-id related to that token. A value of -1 indicates that
|
||||
LoRA doesn't apply to that token.
|
||||
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
|
||||
the A matrix grouped by LoRA IDs.
|
||||
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
|
||||
of tokens that are to be processed by LoRA ID lora_ids[i]
|
||||
lora_token_start_loc (torch.Tensor): A cumulative sum of
|
||||
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
|
||||
lora_token_start_loc[i], along with num_tokens_per_lora[i]
|
||||
identifies the region in token_indices_sorted_by_lora_ids that
|
||||
LoRA lora_ids[i] should process.
|
||||
lora_ids (torch.Tensor): LoRA ids to process.
|
||||
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
|
||||
if there are any requests that require LoRA.
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
|
||||
assert no_lora_flag_cpu.numel() == 1
|
||||
if no_lora_flag_cpu.item():
|
||||
# None of the inputs require LoRA.
|
||||
return
|
||||
|
||||
assert inputs.dtype == lora_a_weights[0].dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
for weight in lora_a_weights:
|
||||
assert weight.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
assert inputs.size(1) == lora_a_weights[0].size(-1)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
# metadata sanity check
|
||||
M = inputs.size(0)
|
||||
assert token_lora_mapping.size(0) == M
|
||||
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
|
||||
0)
|
||||
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
|
||||
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
|
||||
|
||||
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
|
||||
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
|
||||
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
|
||||
NUM_SLICES = len(lora_a_weights)
|
||||
MAX_LORAS = lora_ids.size(0)
|
||||
|
||||
# Triton kernel configs
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 16
|
||||
BLOCK_K = 256 if M < 128 else 32
|
||||
SPLIT_K = 64 if M < 128 else 8
|
||||
NUM_WARPS = 4
|
||||
NUM_CTAS = 1
|
||||
NUM_STAGES = 2
|
||||
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
|
||||
|
||||
# TODO (varun): This grid formulation maximizes parallelization at the
|
||||
# cost of wasteful thread block launch when only few of the input tokens
|
||||
# require LoRA. This might not be the best in all cases.
|
||||
grid = (
|
||||
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
NUM_SLICES,
|
||||
# Each LoRA receives its own set of thread blocks for output
|
||||
# computation. If some LoRA doesn't have any tokens to process, its
|
||||
# thread blocks exit early.
|
||||
MAX_LORAS,
|
||||
)
|
||||
|
||||
_lora_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_ptr_tensor,
|
||||
output_tensor,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
token_indices_sorted_by_lora_ids,
|
||||
num_tokens_per_lora,
|
||||
lora_token_start_loc,
|
||||
lora_ids,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_strides_d0,
|
||||
lora_strides_d1,
|
||||
lora_strides_d2,
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
output_tensor.stride(2),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
NUM_SLICES,
|
||||
num_warps=NUM_WARPS,
|
||||
num_ctas=NUM_CTAS,
|
||||
num_stages=NUM_STAGES,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _lora_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
direct_register_custom_op(
|
||||
op_name="lora_shrink",
|
||||
op_func=_lora_shrink,
|
||||
mutates_args=["output_tensor"],
|
||||
fake_impl=_lora_shrink_fake,
|
||||
)
|
||||
lora_shrink = torch.ops.vllm.lora_shrink
|
||||
|
||||
except AttributeError:
|
||||
lora_shrink = _lora_shrink
|
||||
120
vllm/lora/ops/triton_ops/utils.py
Normal file
120
vllm/lora/ops/triton_ops/utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
|
||||
|
||||
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
|
||||
"""
|
||||
`_LORA_A_PTR_DICT` collects the required information during `profile_run`,
|
||||
After this, it remains constant and subsequent usage is through LUT.
|
||||
Refer to:
|
||||
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
|
||||
"""
|
||||
key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
|
||||
|
||||
if values := _LORA_A_PTR_DICT.get(key):
|
||||
return values
|
||||
|
||||
lora_strides_d0 = []
|
||||
lora_strides_d1 = []
|
||||
lora_strides_d2 = []
|
||||
tensor_ptrs = []
|
||||
for lora_a_weight in lora_a_weights:
|
||||
if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_a_weight.size(1) == 1
|
||||
lora_a_weight = lora_a_weight.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank)
|
||||
assert lora_a_weight.is_contiguous()
|
||||
tensor_ptrs.append(lora_a_weight.data_ptr())
|
||||
lora_strides_d0.append(lora_a_weight.stride(0))
|
||||
lora_strides_d1.append(lora_a_weight.stride(1))
|
||||
lora_strides_d2.append(lora_a_weight.stride(2))
|
||||
if len(lora_a_weights) > 1:
|
||||
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
|
||||
else:
|
||||
lora_ptr_tensor = lora_a_weights[0]
|
||||
|
||||
if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1
|
||||
or len(set(lora_strides_d2)) > 1):
|
||||
raise ValueError("All LoRA weights must have the same stride.")
|
||||
|
||||
_LORA_A_PTR_DICT[key] = (
|
||||
lora_ptr_tensor,
|
||||
lora_strides_d0[0],
|
||||
lora_strides_d1[0],
|
||||
lora_strides_d2[0],
|
||||
)
|
||||
return _LORA_A_PTR_DICT.get(key)
|
||||
|
||||
|
||||
def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
|
||||
device: torch.device):
|
||||
"""
|
||||
`_LORA_B_PTR_DICT` collects the required information during `profile_run`,
|
||||
After this, it remains constant and subsequent usage is through LUT.
|
||||
Refer to:
|
||||
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
|
||||
|
||||
"""
|
||||
|
||||
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
|
||||
if values := _LORA_B_PTR_DICT.get(key):
|
||||
return values
|
||||
slice_offset_lst = []
|
||||
tensor_ptrs = []
|
||||
lora_strides_d0 = []
|
||||
lora_strides_d1 = []
|
||||
lora_strides_d2 = []
|
||||
hidden_sizes = []
|
||||
slice_offset = offset_start
|
||||
for lora_b_weight in lora_weights:
|
||||
if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weight.size(1) == 1
|
||||
lora_b_weight = lora_b_weight.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank)
|
||||
assert lora_b_weight.is_contiguous()
|
||||
tensor_ptrs.append(lora_b_weight.data_ptr())
|
||||
lora_strides_d0.append(lora_b_weight.stride(0))
|
||||
lora_strides_d1.append(lora_b_weight.stride(1))
|
||||
lora_strides_d2.append(lora_b_weight.stride(2))
|
||||
slice_offset_lst.append(slice_offset)
|
||||
slice_offset += lora_b_weight.size(1)
|
||||
hidden_sizes.append(lora_b_weight.size(1))
|
||||
|
||||
if len(lora_weights) > 1:
|
||||
# note these are device tensors
|
||||
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
|
||||
slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
|
||||
else:
|
||||
slice_start_tensor = slice_offset_lst[0]
|
||||
lora_ptr_tensor = lora_b_weight[0]
|
||||
|
||||
# If each lora has the same stride, there's no need to use a
|
||||
# tensor for storage.
|
||||
if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and
|
||||
len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1:
|
||||
lora_strides_d0_tensor = lora_strides_d0[0]
|
||||
lora_strides_d1_tensor = lora_strides_d1[0]
|
||||
lora_strides_d2_tensor = lora_strides_d2[0]
|
||||
hidden_sizes_tensor = hidden_sizes[0]
|
||||
same_stride = True
|
||||
|
||||
else:
|
||||
lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
|
||||
lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
|
||||
lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
|
||||
hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
|
||||
same_stride = False
|
||||
# MAX_N is the maximum hidden size among all the lora_b weights
|
||||
MAX_N = max(hidden_sizes)
|
||||
_LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor,
|
||||
lora_strides_d0_tensor, lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor, hidden_sizes_tensor,
|
||||
same_stride, MAX_N)
|
||||
return _LORA_B_PTR_DICT.get(key)
|
||||
7
vllm/lora/ops/xla_ops/__init__.py
Normal file
7
vllm/lora/ops/xla_ops/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink)
|
||||
|
||||
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
|
||||
145
vllm/lora/ops/xla_ops/lora_ops.py
Normal file
145
vllm/lora/ops/xla_ops/lora_ops.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.core.xla_builder as xb
|
||||
from torch.library import impl
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
|
||||
|
||||
|
||||
@jax.jit
|
||||
def bgmv_jax(inputs, loras, idxs):
|
||||
return jnp.einsum(
|
||||
"td,tX,Xld->tl",
|
||||
inputs,
|
||||
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
|
||||
loras,
|
||||
)
|
||||
|
||||
|
||||
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "XLA")
|
||||
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
|
||||
jax_import_guard()
|
||||
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
|
||||
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
|
||||
idxs: torch.IntTensor):
|
||||
T, _ = inputs.shape
|
||||
if len(loras.shape) == 4:
|
||||
loras = loras.squeeze(axis=1)
|
||||
_, L, _ = loras.shape
|
||||
|
||||
return torch.empty((T, L), device=inputs.device)
|
||||
|
||||
|
||||
def bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
[num_tokens, hidden_size * num_slices].
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
tensor.
|
||||
"""
|
||||
|
||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
if output_tensor.shape[1] > outputs.shape[1]:
|
||||
outputs = F.pad(outputs,
|
||||
(0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
|
||||
|
||||
if add_inputs:
|
||||
return output_tensor + outputs[:limit, :output_tensor.shape[1]]
|
||||
else:
|
||||
return outputs[:limit, :output_tensor.shape[1]]
|
||||
|
||||
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
output_tensor (torch.Tensor): (Unused) output tensor (placeholder).
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
scaling (float, optional): Scalar multiplier applied to the output.
|
||||
"""
|
||||
|
||||
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights,
|
||||
lora_indices_tensor)
|
||||
|
||||
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
|
||||
|
||||
lora_b_weights (torch.Tensor): LoRA weights of shape
|
||||
[num_loras, lora_rank, hidden_size].
|
||||
|
||||
output_tensor (torch.Tensor): output tensor of shape
|
||||
[num_tokens, hidden_size * num_slices].
|
||||
|
||||
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
|
||||
indicating which LoRA matrix to use for each token.
|
||||
add_inputs (bool): Whether or not to add the input tensor to the output
|
||||
tensor.
|
||||
"""
|
||||
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
|
||||
|
||||
outputs = F.pad(
|
||||
outputs,
|
||||
(
|
||||
slice_offset,
|
||||
output_tensor.shape[1] - (slice_offset + slice_size),
|
||||
0,
|
||||
0,
|
||||
),
|
||||
)
|
||||
|
||||
if add_inputs:
|
||||
return output_tensor + outputs
|
||||
else:
|
||||
return outputs
|
||||
Reference in New Issue
Block a user