[Model] Support DeepSeek-V4
This commit is contained in:
41
vllm_mlu/lora/__init__.py
Normal file
41
vllm_mlu/lora/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.lora.layers.column_parallel_linear import (
|
||||
ColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithShardedLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithShardedLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
|
||||
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
|
||||
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
|
||||
from vllm.lora.layers.row_parallel_linear import (
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.utils import LoRAMapping
|
||||
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
|
||||
|
||||
__all__ = [
|
||||
"BaseLayerWithLoRA",
|
||||
"VocabParallelEmbeddingWithLoRA",
|
||||
"LogitsProcessorWithLoRA",
|
||||
"ColumnParallelLinearWithLoRA",
|
||||
"ColumnParallelLinearWithShardedLoRA",
|
||||
"MergedColumnParallelLinearWithLoRA",
|
||||
"MergedColumnParallelLinearWithShardedLoRA",
|
||||
"MergedQKVParallelLinearWithLoRA",
|
||||
"MergedQKVParallelLinearWithShardedLoRA",
|
||||
"QKVParallelLinearWithLoRA",
|
||||
"QKVParallelLinearWithShardedLoRA",
|
||||
"RowParallelLinearWithLoRA",
|
||||
"RowParallelLinearWithShardedLoRA",
|
||||
"ReplicatedLinearWithLoRA",
|
||||
"LoRAMapping",
|
||||
"FusedMoEWithLoRA",
|
||||
]
|
||||
3
vllm_mlu/lora/layers/__init__.py
Normal file
3
vllm_mlu/lora/layers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
50
vllm_mlu/lora/layers/base_linear.py
Normal file
50
vllm_mlu/lora/layers/base_linear.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply(
    self,
    x: torch.Tensor,
    bias: torch.Tensor | None,
    residual: torch.Tensor | None = None,
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual in matmul
    '''
    # NOTE(review): the function name says "row_parallel_linear" but this
    # hijacks BaseLinearLayerWithLoRA.apply (base_linear module) — looks
    # copy-pasted from the row-parallel hijack; confirm the intended name.
    # MLU change: forward `residual` into quant_method.apply so the residual
    # add is fused into the matmul instead of being applied afterwards.
    output = self.base_layer.quant_method.apply(self.base_layer, x, bias, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    # In transformers backend, x and output have extra batch dimension like
    # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
    # therefore we need to flatten the batch dimensions.
    if x.ndim == 3 and output.ndim == 3:
        output = output.flatten(0, 1)
        x = x.flatten(0, 1)

    # Add the LoRA contribution on top of the base-layer output. On
    # platforms that update in place, `output` is modified directly;
    # otherwise the wrapper returns the updated tensor, which we adopt.
    lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
        output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
    )
    if not current_platform.can_update_inplace():
        output = lora_output

    return output


# Install the MLU override for BaseLinearLayerWithLoRA.apply.
MluHijackObject.apply_hijack(
    BaseLinearLayerWithLoRA,
    BaseLinearLayerWithLoRA.apply,
    vllm__lora__layers__row_parallel_linear__BaseLinearLayerWithLoRA__apply,
)
|
||||
39
vllm_mlu/lora/layers/column_parallel_linear.py
Normal file
39
vllm_mlu/lora/layers/column_parallel_linear.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers.column_parallel_linear import ColumnParallelLinearWithLoRA
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
# Keep a reference to the original forward so the hijacked version can
# delegate to it after validating the MLU-specific extra parameters.
vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org = ColumnParallelLinearWithLoRA.forward


'''
=============================
Modify by vllm_mlu
=============================
@brief: add smooth_quant_scale and use_tp_weight parameters.
'''
def vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward(
    self,
    input_,
    smooth_quant_scale: torch.Tensor | None = None,
    use_tp_weight: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Hijacked ``ColumnParallelLinearWithLoRA.forward``.

    Accepts the MLU-specific ``smooth_quant_scale`` and ``use_tp_weight``
    keyword arguments (to stay call-compatible with the MLU
    ``ColumnParallelLinear.forward``), rejects any non-default value for
    them, and delegates to the original vLLM forward.

    Raises:
        AssertionError: if ``use_tp_weight`` is True or
            ``smooth_quant_scale`` is not None.
    """
    # Fix: the first message previously spelled "LoRa"; keep the spelling
    # consistent with the second assertion ("LoRA").
    assert not use_tp_weight, "LoRA does not support use_tp_weight yet."
    assert smooth_quant_scale is None, "LoRA does not support smooth quant yet."
    return vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward_org(self, input_)
'''
==================
End of MLU Hijack
==================
'''


MluHijackObject.apply_hijack(
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithLoRA.forward,
    vllm__lora__layers__column_parallel_linear__ColumnParallelLinearWithLoRA__forward,
)
|
||||
163
vllm_mlu/lora/layers/row_parallel_linear.py
Normal file
163
vllm_mlu/lora/layers/row_parallel_linear.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.lora.layers.row_parallel_linear import (
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply(
    self,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual and bias in matmul
    '''
    # MLU change: pass `bias` and `residual` into quant_method.apply so both
    # are fused into the matmul rather than added afterwards.
    output = self.base_layer.quant_method.apply(
        self.base_layer, x, bias, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # Collapse any leading batch dims so punica sees (num_tokens, dim);
    # the original output shape is restored before returning.
    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    # fp32 accumulation buffer: one plane per LoRA slice, rank-wide columns
    # (lora_a_stacked[0].shape[2] is the LoRA rank shard).
    buffer = torch.zeros(
        (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )

    # Shrink: project x through the (sharded) lora_a matrices into `buffer`.
    # On platforms without in-place update, adopt the returned tensor.
    shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
        buffer, x, self.lora_a_stacked, 1.0
    )
    if not current_platform.can_update_inplace():
        buffer = shrunk_buffer
    if self.tp_size > 1:
        buffer = tensor_model_parallel_all_reduce(buffer)

    # following S-LoRA, allows the fusing of all_gather and all_reduce
    # by adding the column partitioned lora output to a slice of output
    # tensor, which is a partial sum due to row parallel. All that
    # remains is a standard all_reduce. User should be aware though that
    # the output is not the same as a normal row_parallel, it should be
    # reduced before being used
    # NOTE offset are based on the rank.
    shard_size = self.lora_b_stacked[0].shape[2]
    offset_start = self.tp_rank * shard_size
    lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
        output,
        buffer,
        self.lora_b_stacked,
        self.output_slices,
        offset_start=offset_start,
        add_input=True,
    )

    if not current_platform.can_update_inplace():
        output = lora_output

    # Restore the caller-visible shape captured before flattening.
    output = output.view(*out_orig_shape)
    return output
|
||||
|
||||
|
||||
def vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward(
    self,
    input_: torch.Tensor,
    residual: torch.Tensor | None = None,
    smooth_quant_scale: torch.Tensor | None = None,
    use_tp_weight: bool = False,
    output: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Add parameters `residual`, `smooth_quant_scale`, `use_tp_weight` and `output`
            to keep parameters consistent with RowParallelLinear.forward.
    '''
    # Fix: the message previously misspelled "use_tp_wight"; also dropped the
    # pointless f-prefixes (the strings contain no placeholders).
    # NOTE(review): `smooth_quant_scale` is accepted but silently ignored
    # below — confirm whether it should be rejected like in the
    # ColumnParallelLinearWithLoRA hijack.
    assert (not use_tp_weight) and output is None, (
        "RowParallelLinearWithLoRA.forward does not support use_tp_weight=True"
        " or pass output parameters.")
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    # Set up backprop all-reduce.
    if self.base_layer.input_is_parallel:
        input_parallel = input_
    else:
        # TODO: simplify code below
        splitted_input = split_tensor_along_last_dim(
            input_, num_partitions=self.base_layer.tp_size
        )
        input_parallel = splitted_input[self.tp_rank].contiguous()

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: 1) apply residual fusion in matmul like RowParallelLinear
            2) add bias in matmul, not after all reduce
    '''
    # Matrix multiply. Only rank 0 contributes bias/residual so they are
    # counted exactly once in the subsequent all-reduce.
    bias_ = (
        None if (self.base_layer.tp_rank > 0 or self.base_layer.skip_bias_add)
        else self.base_layer.bias
    )
    residual_ = None if self.base_layer.tp_rank > 0 else residual
    output_parallel = self.apply(input_parallel, bias_, residual_)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    if self.base_layer.reduce_results and self.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: do not add bias after all_reduce
    '''
    # Bias was already fused into the matmul above; only hand it back to the
    # caller when skip_bias_add asks for deferred bias application.
    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if not self.base_layer.return_bias:
        return output
    return output, output_bias
|
||||
|
||||
|
||||
# Install the MLU overrides: the sharded apply (bias/residual fused into the
# matmul) and the extended row-parallel forward defined above.
MluHijackObject.apply_hijack(
    RowParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA.apply,
    vllm__lora__layers__row_parallel_linear__RowParallelLinearWithShardedLoRA__apply,
)
MluHijackObject.apply_hijack(
    RowParallelLinearWithLoRA,
    RowParallelLinearWithLoRA.forward,
    vllm__lora__layers__row_parallel_linear__RowParallelLinearWithLoRA__forward,
)
|
||||
3
vllm_mlu/lora/ops/__init__.py
Normal file
3
vllm_mlu/lora/ops/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
16
vllm_mlu/lora/ops/triton_ops/__init__.py
Normal file
16
vllm_mlu/lora/ops/triton_ops/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.sgmv_expand import sgmv_expand_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops.lora_shrink_op import lora_shrink
|
||||
from vllm_mlu.lora.ops.triton_ops.lora_expand_op import lora_expand
|
||||
|
||||
__all__ = [
|
||||
"sgmv_expand_mlu",
|
||||
"sgmv_expand_slice_mlu",
|
||||
"sgmv_shrink_mlu",
|
||||
"lora_expand",
|
||||
"lora_shrink"
|
||||
]
|
||||
308
vllm_mlu/lora/ops/triton_ops/kernel_utils.py
Normal file
308
vllm_mlu/lora/ops/triton_ops/kernel_utils.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Utilities for Punica kernel construction.
|
||||
"""
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: modify mm triton
|
||||
1) add parameter offset_n: mlu add offset_n of matrix B,
|
||||
value: tl.arange(0, BLOCK_N) + pid_n * BLOCK_N, shape: [BLOCK_N]
|
||||
add parameter N: mlu add column number of matrix B
|
||||
2) tiled_b always need mask in case offset_n > N
|
||||
'''
|
||||
|
||||
@triton.jit
def mm_k(
    a_ptr,
    b_ptr,
    ak_stride,
    bk_stride,
    offset_n,
    offset_k,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    N: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    b_dtype: tl.constexpr,
):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate, through the K dimension to compute the partial/complete
    matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.
    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        offset_n: column offsets of B (arange(0, BLOCK_N) + pid_n * BLOCK_N);
            MLU addition used to mask out-of-range columns
        offset_k: K-dimension offsets of the current block
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
            masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        N: number of columns of B (MLU addition, bound for offset_n mask)
        CAST_TYPE: if True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: datatype of the B matrix
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            # MLU change: B is always masked on columns, since offset_n may
            # exceed N even when K needs no masking.
            tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N, other=0.0)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=(offset_k[:, None]
                                    < K - k * (BLOCK_K * SPLIT_K)) & (offset_n < N)[None, :],
                              other=0.0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(b_dtype)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        # Advance both operands one K-block (times SPLIT_K) per iteration.
        a_ptr += BLOCK_K * SPLIT_K * ak_stride
        b_ptr += BLOCK_K * SPLIT_K * bk_stride
    return accumulator
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@triton.jit
def do_expand_kernel(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice,
    compute the matrix product and store in the appropriate output location.
    Given that this is an expand kernel, we don't perform any split-K reduction
    as the K dimension is assumed to be small.
    """

    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:  # 'same_stride': True
        # integer
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        # pointer
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)

    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(out_ptr.dtype.element_ty))

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: 1) remove rbn definition: mlu doesn't support contiguous and
               will handle as head corruption
            2) re-write b_ptr, use offset_n to identify its position
    '''

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
             offset_k[None, :] * input_d2_stride)
    # b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
    #          offset_k[:, None] * cur_lora_d2_stride +
    #          rbn[None, :] * cur_lora_d1_stride)
    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
             offset_k[:, None] * cur_lora_d2_stride +
             offset_n[None, :] * cur_lora_d1_stride)

    # Compute the block matrix product.
    SPLIT_K = 1
    accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride, offset_n,
                       offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N,
                       CAST_TYPE, cur_lora_ptr.dtype.element_ty)

    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    # Column offsets are shifted by the slice's start so each slice writes
    # into its own region of the fused output tensor.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
             offset_cn[None, :] * output_d1_stride)
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
                                             < (cur_slice_start + N))

    # When ADD_INPUTS, accumulate onto the existing output values before the
    # (unconditional) masked store.
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
def do_shrink_kernel(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice, compute the
    matrix product and store in the appropriate output location.
    """

    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        # current lora ptr
        cur_lora_ptr = lora_ptr
    else:
        # current lora ptr
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(input_ptr.dtype.element_ty))

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: 1) remove rbn definition: mlu doesn't support contiguous and
               will handle as head corruption
            2) re-write b_ptr, use offset_n to identify its position
    '''

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers.
    # offset_k starts at pid_sk * BLOCK_K to support split-K partial products.
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
             offset_k[None, :] * input_d1_stride)
    # b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
    #          rbn[None, :] * lora_d1_stride +
    #          offset_k[:, None] * lora_d2_stride)
    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
             offset_n[None, :] * lora_d1_stride +
             offset_k[:, None] * lora_d2_stride)

    # Compute partial/complete block matrix product.
    accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_n, offset_k,
                       K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, N, False,
                       cur_lora_ptr.dtype.element_ty)

    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
                   slice_id * output_d0_stride)
    c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
        None, :] * output_d2_stride
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)

    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        # Split-K: partial products from different pid_sk are summed
        # atomically into the same output block.
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
308
vllm_mlu/lora/ops/triton_ops/lora_expand_op.py
Normal file
308
vllm_mlu/lora/ops/triton_ops/lora_expand_op.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use vllm_mlu hijacked kernel
|
||||
'''
|
||||
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
|
||||
|
||||
|
||||
@triton.jit
def _lora_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    slice_start_loc,
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,  # 1
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,  # 1
    output_d0_stride,
    output_d1_stride,  # 1
    output_hs_ptr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
):
    # Grid: axis 0 covers the (M, N) tile space, axis 1 the slices,
    # axis 2 one program group per LoRA id.
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn % cta_m_num
    pid_n = (pid_mn // cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # When the output dimensions of each slice are the same,cur_n=N, otherwise
    # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
    # qkv linear.
    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
    if pid_n * BLOCK_N >= curr_N:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (
        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
    )

    # Load all relevant row indices.
    # The modulo wraps padding lanes back onto valid rows; out-of-range rows
    # are masked at store time via cta_m_len.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    # Delegate the actual tile product + store to the shared helper.
    do_expand_kernel(
        pid_n,
        lora_id,
        slice_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        curr_N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        slice_start_loc,
        # input ptr strides
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,
        # lora ptr strides
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,
        # out ptr strides
        output_d0_stride,
        output_d1_stride,
        # constants
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SAME_STRIDE,
        SLICE_NUM,
        EVEN_K,
        CAST_TYPE,
        ADD_INPUTS,
    )
|
||||
|
||||
|
||||
@torch.inference_mode()
def _lora_expand(
    inputs: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    lora_b_weights: list[torch.Tensor],  # shape [num_lora, hidden_size, lora_rank]
    output_tensor: torch.Tensor,  # shape [num_tokens, hidden_size * num_slices]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    no_lora_flag_cpu: torch.Tensor,  # shape [1]
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    """
    Expand the shrunk LoRA activations back to hidden size, writing into
    output_tensor in place.

    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (list[torch.Tensor]): lora'b weight
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates that
            LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
            the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
            of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i]
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
            if there are any requests that require LoRA.
        offset_start (int, optional): Offset start for output_tensor.
            Defaults to 0.
        add_inputs (bool, optional): Whether to add the input tensor to the
            output tensor. Defaults to False.
    """

    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA.
        return

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    for weight in lora_b_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(0) == len(lora_b_weights)
    assert output_tensor.is_contiguous()

    # metadata sanity check.
    M = inputs.size(1)  # num_tokens
    assert token_lora_mapping.size(0) == M
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    # Pack the per-slice lora_b pointers/strides into device tensors the
    # kernel can index by slice_id.
    (
        slice_start_tensor,
        lora_ptr_tensor,
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        hidden_sizes_tensor,
        same_stride,
        MAX_N,
    ) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)

    K = lora_b_weights[0].shape[-1]  # K= rank
    ADD_INPUTS = add_inputs
    MAX_LORAS = lora_ids.size(0)
    CAST_TYPE = False
    NUM_SLICES = len(lora_b_weights)

    # Triton kernel configs.
    BLOCK_M = 64
    BLOCK_N = 128
    BLOCK_K = 16
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2

    EVEN_K = K % BLOCK_K == 0  # type: ignore

    # fp32 inputs against fp16/bf16 weights require a cast inside the kernel.
    if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
        torch.float16,
        torch.bfloat16,
    ]:
        CAST_TYPE = True

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only a few input tokens require
    # LoRA. This might not be the best in all cases.
    grid = (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks simply exit.
        MAX_LORAS,
    )

    _lora_expand_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        MAX_N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        slice_start_tensor,
        inputs.stride(0),
        inputs.stride(1),
        inputs.stride(2),
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        hidden_sizes_tensor,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
        NUM_SLICES,
        same_stride,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
    )

    return
|
||||
|
||||
|
||||
def _lora_expand_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
offset_start: int = 0,
|
||||
add_inputs: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: use only vllm operand
'''

# Public alias: expose the plain Python implementation `_lora_expand`
# directly (no direct_register_custom_op wrapping in this file).
lora_expand = _lora_expand

'''
==================
End of MLU Hijack
==================
'''
|
||||
258
vllm_mlu/lora/ops/triton_ops/lora_shrink_op.py
Normal file
258
vllm_mlu/lora/ops/triton_ops/lora_shrink_op.py
Normal file
@@ -0,0 +1,258 @@
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use vllm_mlu hijacked kernel
|
||||
'''
|
||||
from vllm_mlu.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
|
||||
|
||||
|
||||
@triton.jit
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
                        token_indices_sorted_by_lora_ids, num_tokens_per_lora,
                        lora_token_start_loc, lora_ids, scaling,
                        input_d0_stride, input_d1_stride, lora_d0_stride,
                        lora_d1_stride, lora_d2_stride, output_d0_stride,
                        output_d1_stride, output_d2_stride,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                        BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
                        SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
    """LoRA shrink kernel entry point.

    Grid layout: axis 0 packs (SPLIT_K, M-tiles, N-tiles); axis 1 is the
    slice index; axis 2 is the LoRA index. Each program gathers the token
    rows belonging to its LoRA id and delegates the actual GEMM to the
    MLU-hijacked ``do_shrink_kernel`` helper.
    """

    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # Decompose the packed axis-0 program id into (split-k, m-tile, n-tile).
    pid_sk_m_n = tl.program_id(axis=0)
    pid_sk = pid_sk_m_n % SPLIT_K
    pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
    pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    # Number of tokens routed to this LoRA id.
    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
                            lora_m_indices_start + cta_m_offset)

    # Load all relevant row indices. The `% cta_m_len` wraps out-of-range
    # lanes back onto valid rows; the helper masks using cta_m_len.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_shrink_kernel(
        pid_n,
        pid_sk,
        slice_id,
        lora_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _lora_shrink(
    inputs: torch.Tensor,  # shape [num_tokens, hidden_size]
    lora_a_weights: list[
        torch.Tensor],  # shape [num_loras, lora_rank, hidden_size]
    output_tensor: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    no_lora_flag_cpu: torch.Tensor,  # shape [1]
    scaling: float,
) -> None:
    """Launch the LoRA shrink Triton kernel (x @ A^T, scaled).

    Args:
        inputs (torch.Tensor): Input tensor
        lora_a_weights (list[torch.Tensor]): LoRA weights
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates that
            LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
            the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
            of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i]
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
            if there are any requests that require LoRA.
        scaling (float): Scaling factor.
    """

    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA.
        return

    # dtype sanity: shrink requires half-precision inputs matching weights.
    assert inputs.dtype == lora_a_weights[0].dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    for weight in lora_a_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(1) == lora_a_weights[0].size(-1)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    # metadata sanity check
    M = inputs.size(0)
    assert token_lora_mapping.size(0) == M
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
        0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    # Pack per-slice weight pointers/strides into device tensors.
    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
     lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
    N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
    NUM_SLICES = len(lora_a_weights)
    MAX_LORAS = lora_ids.size(0)

    # Triton kernel configs
    BLOCK_M = 32
    BLOCK_N = 16
    # Heuristic: small batches favor a deep split-K reduction.
    BLOCK_K = 256 if M < 128 else 32
    SPLIT_K = 64 if M < 128 else 8
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2

    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only few of the input tokens
    # require LoRA. This might not be the best in all cases.
    grid = (
        SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks exit early.
        MAX_LORAS,
    )

    _lora_shrink_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        lora_strides_d0,
        lora_strides_d1,
        lora_strides_d2,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor.stride(2),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        NUM_SLICES,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
    )

    return
|
||||
|
||||
|
||||
def _lora_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
'''
=============================
Modify by vllm_mlu
=============================
@brief: use only vllm operand
'''

# Public alias: expose the plain Python implementation `_lora_shrink`
# directly (no direct_register_custom_op wrapping in this file).
lora_shrink = _lora_shrink

'''
==================
End of MLU Hijack
==================
'''
|
||||
238
vllm_mlu/lora/ops/triton_ops/sgmv_expand.py
Normal file
238
vllm_mlu/lora/ops/triton_ops/sgmv_expand.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
def _sgmv_expand_kernel_mlu(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    The sgmv's expand triton kernel is based on GroupGEMM.

    Grid: axis 0 packs (M-tiles x N-tiles); axis 1 is the batch index.
    Computes out[seq_rows, :N] (+)= input[seq_rows, :K] @ B for the LoRA
    selected by lora_indices[cur_batch]; a -1 index skips the batch.
    """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: adjust kernel impl to fit mlu.
    '''
    # NOTE: despite the names, offset_k advances along lora_n_stride and
    # offset_n along lora_k_stride — the caller passes stride(1)/stride(2)
    # such that this indexes B as [k, n]; verified against the launcher.
    a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
        offset_k[None, :] * xk_stride
    b_ptr = lora_ptr + l0_stride * lora_index + \
        offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: adjust kernel impl to fit mlu.
        '''
        # EVEN_K guarantees K % BLOCK_K == 0, so no k-masking is needed.
        if EVEN_K:
            tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
            tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
                              other=0)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if CAST_TYPE:
            # Cast fp32 activations down to the weight dtype before tl.dot.
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    # NOTE(review): M was already loaded above; this reload is redundant
    # but harmless.
    M = tl.load(seq_lens + cur_batch)
    c_mask = (offset_cm[:, None] <
              (cur_seq_start + M)) & (offset_cn[None, :] < N)
    if ADD_INPUTS:
        # Accumulate into existing output instead of overwriting it.
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def sgmv_expand_mlu(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
    """Launch the MLU SGMV expand kernel (per-sequence LoRA B projection).

    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (torch.Tensor): LoRA B weight
        output_tensor (torch.Tensor): output tensor
        b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
            sequence lengths of the sequences in the batch, used to index
            into sequence. E.g., if the sequence length is [4, 6], it is
            [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should be
            applied.
        batches (int): batch size
        max_seq_length (int): The max sequence lengths of the sequences in the
            batch.
        token_nums (int): The token numbers in the batch. Used to verify if the
            token numbers in the inputs matches the one in the metadata.
        add_inputs (bool, optional): Defaults to False, adds the final lora
            results to the output.
    """

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)

    assert lora_b_weights.is_contiguous()

    # TODO tuning this config

    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Workaround: Adjust block size to meet mlu restrictions.

    The grid of mlu triton kernel must less than 65536, it will be out of bound when
    the input seq is very long, and causes runtime error. So we need to adjust the block
    size to avoid this.
    '''
    BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    # fp32 activations with half-precision weights require an in-kernel cast.
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: call _sgmv_expand_kernel_mlu
    '''
    _sgmv_expand_kernel_mlu[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return
|
||||
248
vllm_mlu/lora/ops/triton_ops/sgmv_expand_slice.py
Normal file
248
vllm_mlu/lora/ops/triton_ops/sgmv_expand_slice.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
@triton.jit
def _sgmv_expand_slice_kernel_mlu(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    Similar to the 'sgmv_expand' operator, but with an added parameter
    'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
    might be that in the future, we could implement a fusion operator to
    achieve the current functionality instead of having to call it multiple
    times.

    Output columns are written at [slice_offset, slice_offset + N).
    """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: adjust kernel impl to fit mlu.
    '''
    # NOTE: offset_k advances along lora_n_stride and offset_n along
    # lora_k_stride; the launcher passes stride(1)/stride(2) so this
    # indexes the weight as [k, n].
    a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
        offset_k[None, :] * xk_stride
    b_ptr = lora_ptr + l0_stride * lora_index + \
        offset_k[:, None] * lora_n_stride + offset_n[None, :] * lora_k_stride
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: adjust kernel impl to fit mlu.
        '''
        # EVEN_K guarantees K % BLOCK_K == 0, so no k-masking is needed.
        if EVEN_K:
            tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
            tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=((offset_k[None, :] < K - k * BLOCK_K) & (offset_m[:, None] < M)),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=((offset_k[:, None] < K - k * BLOCK_K) & (offset_n[None, :] < N)),
                              other=0)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        if CAST_TYPE:
            # Cast fp32 activations down to the weight dtype before tl.dot.
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    # Column offsets are shifted by slice_offset into the fused output.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    # NOTE(review): M was already loaded above; this reload is redundant
    # but harmless.
    M = tl.load(seq_lens + cur_batch)
    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
                                                           (slice_offset + N))
    if ADD_INPUTS:
        # Accumulate into existing output instead of overwriting it.
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def sgmv_expand_slice_mlu(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    """Launch the MLU SGMV expand-slice kernel, writing the LoRA B
    projection into a column slice of a fused output tensor.

    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (torch.Tensor): LoRA B weight
        output_tensor (torch.Tensor): output tensor
        b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
            sequence lengths of the sequences in the batch, used to index
            into sequence. E.g., if the sequence length is [4, 6], it is
            [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should be
            applied.
        batches (int): batch size
        max_seq_length (int): The max sequence lengths of the sequences
            in the batch
        token_nums (int): The token numbers in the batch. Used to verify if the
            token numbers in the inputs matches the one in the metadata.
        slice_offset (int): output_tensor's offset
        slice_size (int): current output_tensor's size
        add_inputs (bool, optional): Defaults to False, adds the final lora
            results to the output.
    """

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)

    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Workaround: Adjust block size to meet mlu restrictions.

    The grid of mlu triton kernel must less than 65536, it will be out of bound when
    the input seq is very long, and causes runtime error. So we need to adjust the block
    size to avoid this.
    '''
    BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 32)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    # fp32 activations with half-precision weights require an in-kernel cast.
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: call _sgmv_expand_kernel_mlu
    '''
    _sgmv_expand_slice_kernel_mlu[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return
|
||||
231
vllm_mlu/lora/ops/triton_ops/sgmv_shrink.py
Normal file
231
vllm_mlu/lora/ops/triton_ops/sgmv_shrink.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm_mlu.lora.ops.triton_ops.utils import adjust_kernel_block_size
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@triton.jit
def _sgmv_shrink_kernel_mlu(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    scaling,
    xm_stride,  # hidden_size
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """
    The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
    The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
    introducing SPLIT-K can improve performance

    Grid: axis 0 packs (M-tiles x N-tiles); axis 1 is the split-K index;
    axis 2 is the batch index. With SPLIT_K > 1 the partial results are
    combined via atomic adds, so the output must be zero-initialized by
    the caller.
    """
    pid = tl.program_id(axis=0)
    pid_sk = tl.program_id(axis=1)
    cur_batch = tl.program_id(axis=2)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num

    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Each split-K program starts at its own K offset.
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: adjust kernel impl to fit mlu.
    '''
    a_ptr = input_ptr + cur_seq_start * xm_stride + offset_m[:, None] * xm_stride + \
        offset_k[None, :] * xk_stride
    b_ptr = lora_ptr + l0_stride * lora_index + offset_n[None, :] * lora_k_stride + \
        offset_k[:, None] * lora_n_stride
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: adjust kernel impl to fit mlu.
        '''
        # EVEN_K guarantees K % (BLOCK_K * SPLIT_K) == 0 -> no k-masking.
        if EVEN_K:
            tiled_a = tl.load(a_ptr, mask=offset_m[:, None] < M)
            tiled_b = tl.load(b_ptr, mask=offset_n[None, :] < N)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            tiled_a = tl.load(a_ptr,
                              mask=((offset_k[None, :] < k_remaining) & (offset_m[:, None] < M)),
                              other=0.0)
            tiled_b = tl.load(b_ptr,
                              mask=((offset_k[:, None] < k_remaining) & (offset_n[None, :] < N)),
                              other=0.0)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''
        accumulator += tl.dot(tiled_a, tiled_b)

        # Advance by the full split-K stride so programs interleave over K.
        a_ptr += BLOCK_K * SPLIT_K * xk_stride
        b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M

    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    c_mask = (offset_cm[:, None] <
              (cur_seq_start + M)) & (offset_cn[None, :] < N)
    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def sgmv_shrink_mlu(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
    """Launch the MLU SGMV shrink kernel (scaled x @ A^T per sequence).

    Args:
        inputs (torch.Tensor): input tensor
        lora_a_weights (torch.Tensor): LoRA A weight
        output_tensor (torch.Tensor): output tensor
        b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
            sequence lengths of the sequences in the batch, used to index
            into sequence. E.g., if the sequence length is [4, 6], it is
            [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should be
            applied.
        batches (int): batch size
        max_seq_length (int): The max sequence lengths of the sequences in the
            batch.
        token_nums (int): The token numbers in the batch. Used to verify if the
            token numbers in the inputs matches the one in the metadata.
        scaling (float): Scaling factor.
    """
    assert inputs.dtype == lora_a_weights.dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    assert lora_a_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()

    if lora_a_weights.ndim == 4:  # shape:(lora_num,1,rank, size)
        assert lora_a_weights.size(1) == 1
        lora_a_weights = lora_a_weights.squeeze(dim=1)
    else:
        assert lora_a_weights.ndim == 3  # shape:(lora_num,rank, size)
    assert lora_a_weights.is_contiguous()
    assert output_tensor.is_contiguous()
    # TODO tuning this config
    N, K = lora_a_weights.shape[-2:]  # K=hidden_size,N=rank
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Workaround: adjust block size to meet mlu restrictions.

    The grid of mlu triton kernel must less than 65536, it will be out of bound when
    the input seq is very long, and causes runtime error. So we need to adjust the block
    size to avoid this.
    '''
    BLOCK_M, BLOCK_N = adjust_kernel_block_size(max_seq_length, 32, N, 16)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    BLOCK_K = 32
    SPLIT_K = 8
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        # SPLIT_K > 1: kernel writes partial sums via atomic_add, so the
        # caller is expected to provide a zero-initialized output_tensor.
        SPLIT_K,
        batches,
    )
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: call _sgmv_shrink_kernel_mlu
    '''
    _sgmv_shrink_kernel_mlu[grid](
        inputs,
        lora_a_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        lora_a_weights.stride(0),
        lora_a_weights.stride(1),
        lora_a_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
    )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return
|
||||
41
vllm_mlu/lora/ops/triton_ops/utils.py
Normal file
41
vllm_mlu/lora/ops/triton_ops/utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Tuple
|
||||
from math import ceil
|
||||
|
||||
# Hard upper bound on the MLU triton launch grid size (per dimension product).
_MLU_MAX_GRID_SIZE = 65536


def adjust_kernel_block_size(
    m: int,
    block_m: int,
    n: int,
    block_n: int
) -> Tuple[int, int]:
    """Grow (block_m, block_n) until the triton grid fits the MLU limit.

    The MLU triton grid ``ceil(m / block_m) * ceil(n / block_n)`` must stay
    below ``_MLU_MAX_GRID_SIZE`` (65536); for very long sequences the
    requested block sizes overflow it and cause a runtime error.  Both block
    sizes are therefore advanced in lockstep through a fixed candidate list
    until the grid fits.

    Reference sizes used to pick the candidate list:

        LLama3.1-8b-tp1   max n is 14336
        LLama3.1-70b-tp4  max n is 7168
        LLama3.1-405b-tp8 max n is 6656

    NOTE(review): the original docstring claimed that for n=14336 a block
    size of 256 supports sequences up to floor(65536 / ceil(14336 / 256)) *
    256 = 299520, but the search pairs block_m=256 with block_n=192 when
    ``block_n == 16`` on entry, which raises earlier — confirm the intended
    pairing.

    Args:
        m: Problem size along M (the max sequence length in the batch).
        block_m: Requested starting M block size.  NOTE(review): effectively
            ignored — the search always starts at candidate 32; confirm
            callers always pass 32.
        n: Problem size along N.
        block_n: Requested starting N block size; only ``16`` vs anything
            else changes the starting candidate.

    Returns:
        A ``(block_m, block_n)`` pair whose grid fits under the limit.

    Raises:
        ValueError: If even the largest candidate blocks cannot bring the
            grid under ``_MLU_MAX_GRID_SIZE``.
    """
    candidates_list = [16, 32, 64, 96, 128, 192, 256]
    candidates_list_len = len(candidates_list)
    # M starts at candidate 32; N starts at 16 only when explicitly asked.
    m_idx = 1
    n_idx = 0 if block_n == 16 else 1
    while m_idx < candidates_list_len and n_idx < candidates_list_len:
        block_m = candidates_list[m_idx]
        block_n = candidates_list[n_idx]
        if ceil(m / block_m) * ceil(n / block_n) < _MLU_MAX_GRID_SIZE:
            break
        # Grid still too large: grow both block sizes in lockstep and retry.
        # (The while-condition already bounds both indices, so no extra
        # guards are needed here.)
        m_idx += 1
        n_idx += 1
    if ceil(m / block_m) * ceil(n / block_n) >= _MLU_MAX_GRID_SIZE:
        raise ValueError(f"the max seq len {m} is too long for lora triton kernel")
    return block_m, block_n
|
||||
3
vllm_mlu/lora/punica_wrapper/__init__.py
Normal file
3
vllm_mlu/lora/punica_wrapper/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
89
vllm_mlu/lora/punica_wrapper/punica_mlu.py
Normal file
89
vllm_mlu/lora/punica_wrapper/punica_mlu.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union, final
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops import sgmv_expand_slice_mlu
|
||||
from vllm_mlu.lora.ops.triton_ops import sgmv_shrink_mlu
|
||||
|
||||
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
|
||||
|
||||
|
||||
@final
class PunicaWrapperMLU(PunicaWrapperCPU):
    """MLU-specific Punica wrapper.

    Maintains the Multi-LoRA state/metadata inherited from
    ``PunicaWrapperCPU`` and overrides the prefill shrink/expand hooks so
    they dispatch to the MLU triton SGMV kernels.
    """

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Run the SGMV shrink kernel (x @ lora_a, scaled) into y."""
        # Fast path: the current batch carries no LoRA requests.
        if self.no_lora:
            return
        sgmv_shrink_mlu(x, w_t_all, y, *self.prefill_metadata, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        """Run the SGMV expand kernel (x @ lora_b) into y."""
        # Fast path: the current batch carries no LoRA requests.
        if self.no_lora:
            return
        sgmv_expand_mlu(x, w_t_all, y, *self.prefill_metadata, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        """Run the SGMV expand kernel into a column slice of y.

        Used for merged/packed layers where each LoRA writes a distinct
        ``[y_offset, y_offset + y_slice_size)`` slice of the output.
        """
        # Fast path: the current batch carries no LoRA requests.
        if self.no_lora:
            return
        sgmv_expand_slice_mlu(
            x, w_t_all, y, *self.prefill_metadata,
            y_offset, y_slice_size, add_inputs,
        )
|
||||
Reference in New Issue
Block a user