Sync from v0.13
This commit is contained in:
0
vllm/lora/ops/__init__.py
Normal file
0
vllm/lora/ops/__init__.py
Normal file
6
vllm/lora/ops/ipex_ops/__init__.py
Normal file
6
vllm/lora/ops/ipex_ops/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
|
||||
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
|
||||
57
vllm/lora/ops/ipex_ops/lora_ops.py
Normal file
57
vllm/lora/ops/ipex_ops/lora_ops.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    # Re-raise with an actionable message instead of the bare `raise e`,
    # which added nothing over letting the ImportError propagate.
    raise ImportError(
        "intel_extension_for_pytorch is required for the IPEX LoRA ops "
        "but could not be imported."
    ) from e
|
||||
|
||||
|
||||
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> None:
    """Dispatch the batched-GEMV LoRA shrink to the fused IPEX kernel.

    Writes the scaled projection of ``inputs`` through the per-token
    LoRA-A weights into ``output_tensor`` in place.
    """
    ipex.llm.functional.bgmv_shrink(
        inputs,
        lora_a_weights,
        output_tensor,
        lora_indices_tensor,
        scaling,
    )
|
||||
|
||||
|
||||
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
    """Dispatch the batched-GEMV LoRA expand to the fused IPEX kernel.

    Writes (or accumulates, when ``add_inputs``) the expansion of
    ``inputs`` through the per-token LoRA-B weights into
    ``output_tensor`` in place.
    """
    ipex.llm.functional.bgmv_expand(
        inputs,
        lora_b_weights,
        output_tensor,
        lora_indices_tensor,
        add_inputs,
    )
|
||||
|
||||
|
||||
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """Dispatch the sliced batched-GEMV LoRA expand to the IPEX kernel.

    The result lands in the column window
    ``output_tensor[:, slice_offset:slice_offset + slice_size]``.
    """
    ipex.llm.functional.bgmv_expand_slice(
        inputs,
        lora_b_weights,
        output_tensor,
        lora_indices_tensor,
        slice_offset,
        slice_size,
        add_inputs,
    )
|
||||
20
vllm/lora/ops/torch_ops/__init__.py
Normal file
20
vllm/lora/ops/torch_ops/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.torch_ops.lora_ops import (
|
||||
bgmv_expand, # noqa: F401
|
||||
bgmv_expand_slice,
|
||||
bgmv_shrink,
|
||||
sgmv_expand,
|
||||
sgmv_expand_slice,
|
||||
sgmv_shrink,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"bgmv_expand",
|
||||
"bgmv_expand_slice",
|
||||
"bgmv_shrink",
|
||||
"sgmv_expand",
|
||||
"sgmv_expand_slice",
|
||||
"sgmv_shrink",
|
||||
]
|
||||
128
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
128
vllm/lora/ops/torch_ops/lora_ops.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
):
    """Grouped (per-sequence) LoRA expand, reduced to a batched expand.

    Each sequence's LoRA index is repeated once per token of that
    sequence, after which the whole batch is handled by ``bgmv_expand``.
    The remaining sequence-layout arguments are unused by this reference
    implementation but kept for kernel-signature parity.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
    bgmv_expand(inputs, lora_b_weights, output_tensor, per_token_indices, add_inputs)
|
||||
|
||||
|
||||
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    """Per-token LoRA expand: ``output[b] (+)= B[idx[b]] @ inputs[b]``."""
    # Gather each token's LoRA-B matrix; stacked weights may carry an
    # extra unit dimension that is dropped here.
    gathered = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if gathered.dim() == 4:
        gathered = gathered.squeeze(dim=1)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    expanded = torch.einsum("bi, boi -> bo", cast_inputs, gathered)

    # A single-row result aimed at a taller output is clipped to one row.
    if expanded.shape[0] == 1 and output_tensor.shape[0] != 1:
        row_limit = 1
    else:
        row_limit = output_tensor.shape[0]

    # LoRA adapter and model may add different amounts of padding to
    # output; write only the overlapping columns.
    width = min(expanded.shape[1], output_tensor.shape[1])

    if add_inputs:
        output_tensor[:, :width] += expanded[:row_limit, :width]
    else:
        output_tensor[:, :width] = expanded[:row_limit, :width]
|
||||
|
||||
|
||||
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    """Grouped (per-sequence) LoRA shrink, reduced to a batched shrink.

    Each sequence's LoRA index is repeated once per token of that
    sequence, then the whole batch is handled by ``bgmv_shrink``.
    The remaining sequence-layout arguments are unused by this reference
    implementation but kept for kernel-signature parity.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
    bgmv_shrink(inputs, lora_a_weights, output_tensor, per_token_indices, scaling)
|
||||
|
||||
|
||||
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    """Per-token LoRA shrink: ``output[b] = scaling * A[idx[b]] @ inputs[b]``.

    NOTE(review): despite the parameter name ``lora_b_weights``, the
    shrink path consumes the LoRA-A stack — verify against callers.
    """
    gathered = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    # Drop the extra unit dimension that stacked weights may carry.
    if gathered.dim() == 4:
        gathered = gathered.squeeze(dim=1)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    shrunk = torch.einsum("bi, boi -> bo", cast_inputs, gathered)

    # Only the leading rank columns are written; trailing output padding
    # (if any) is left untouched.
    output_tensor[:, : shrunk.shape[1]] = scaling * shrunk
|
||||
|
||||
|
||||
def sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
):
    """Grouped LoRA expand into a column slice of ``output_tensor``.

    Per-sequence LoRA indices are exploded to per-token indices, then
    delegated to ``bgmv_expand_slice``.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
    bgmv_expand_slice(
        inputs,
        lora_b_weights,
        output_tensor,
        per_token_indices,
        slice_offset,
        slice_size,
        add_inputs,
    )
|
||||
|
||||
|
||||
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """Per-token LoRA expand written into
    ``output_tensor[:, slice_offset:slice_offset + slice_size]``."""
    gathered = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    # Drop the extra unit dimension that stacked weights may carry.
    if gathered.dim() == 4:
        gathered = gathered.squeeze(dim=1)
    expanded = torch.einsum("bi, boi -> bo", cast_inputs, gathered)

    cols = slice(slice_offset, slice_offset + slice_size)
    if add_inputs:
        output_tensor[:, cols] += expanded
    else:
        output_tensor[:, cols] = expanded
|
||||
60
vllm/lora/ops/triton_ops/README_TUNING.md
Normal file
60
vllm/lora/ops/triton_ops/README_TUNING.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# Multi-LoRA Tuning
|
||||
|
||||
**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`.
|
||||
Without this, the shrink/expand kernels will use default configurations.
|
||||
|
||||
## Tuning Process
|
||||
|
||||
Multi-lora shrink/expand Triton kernel tuning follows a methodology similar to
|
||||
[Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).
|
||||
|
||||
1. Define the search space. Here is an example of a search space:
|
||||
|
||||
```python
|
||||
block_m_range = [16, 32, 64, 128, 256]
|
||||
block_n_range = [32, 64, 128, 256]
|
||||
block_k_range = [32, 64, 128, 256]
|
||||
num_warps_range = [4, 8]
|
||||
num_stage_range = [2, 3, 4, 5]
|
||||
num_ctas_range = [1]
|
||||
split_k_range = [4, 8, 16, 32, 64]
|
||||
```
|
||||
|
||||
2. Get all hidden_state sizes and num_slices that the target model uses for a specific TP size.
|
||||
|
||||
For example, you can acquire the info by simply checking
|
||||
[add_lora_linear](https://github.com/vllm-project/vllm/blob/main/vllm/lora/punica_wrapper/punica_gpu.py#L181):
|
||||
|
||||
```python
|
||||
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
|
||||
print(f"num_slices: {len(output_slices)}")
|
||||
for i in range(len(output_slices)):
|
||||
print(f"a{i} shape: {lora_a_stacked[i].shape}")
|
||||
print(f"b{i} shape: {lora_b_stacked[i].shape}")
|
||||
print("y_shape", y.shape)
|
||||
```
|
||||
|
||||
3. Benchmark the shrink/expand kernel runtime with different kernel configurations generated from the pre-defined search space
|
||||
by performing a grid search to find the optimal kernel configuration.
|
||||
vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py)
|
||||
can be used to search for configurations for different shapes.
|
||||
|
||||
## Config Files
|
||||
|
||||
### File Naming
|
||||
|
||||
| Kernel Type | File Name Template | Example |
|
||||
|---------------------------|--------------------------------------------|---------------------------------------------|
|
||||
| shrink | `{gpu_name}_SHRINK.json` | `NVIDIA_H200_SHRINK.json` |
|
||||
| expand | `{gpu_name}_EXPAND_{add_input}.json` | `NVIDIA_H200_EXPAND_TRUE.json` |
|
||||
| fused_moe_lora_w13_shrink | `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json` | `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json` |
|
||||
| fused_moe_lora_w13_expand | `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json` | `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json` |
|
||||
| fused_moe_lora_w2_shrink | `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json` | `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json` |
|
||||
| fused_moe_lora_w2_expand | `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json` | `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json` |
|
||||
|
||||
The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()`.
|
||||
|
||||
### JSON Structure
|
||||
|
||||
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]`,
|
||||
where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer.
|
||||
21
vllm/lora/ops/triton_ops/__init__.py
Normal file
21
vllm/lora/ops/triton_ops/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
|
||||
fused_moe_lora,
|
||||
fused_moe_lora_expand,
|
||||
fused_moe_lora_shrink,
|
||||
)
|
||||
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
|
||||
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
|
||||
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
|
||||
|
||||
__all__ = [
|
||||
"lora_expand",
|
||||
"lora_shrink",
|
||||
"LoRAKernelMeta",
|
||||
"fused_moe_lora",
|
||||
"fused_moe_lora_shrink",
|
||||
"fused_moe_lora_expand",
|
||||
]
|
||||
665
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Normal file
665
vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Normal file
@@ -0,0 +1,665 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .utils import supports_pdl
|
||||
|
||||
# Memoized pointer tables, keyed by the data pointers of the stacked
# LoRA weights. Fixed annotation: `torch.Tensor` is the type;
# `torch.tensor` is the factory function and is wrong in an annotation.
_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}


def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
    """
    Return a uint64 tensor holding the data pointer of each stacked LoRA
    weight, memoized per unique pointer tuple.

    `_LORA_PTR_DICT` collects the required information during `profile_run`,
    After this, it remains constant and subsequent usage is through LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)

    if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
        return ptr_tensor

    # Cache miss: materialize the pointer table once and memoize it.
    ptr_tensor = torch.tensor(
        [lora_weight.data_ptr() for lora_weight in lora_weights],
        device=device,
        dtype=torch.uint64,
    )
    _LORA_PTR_DICT[key] = ptr_tensor
    # Return the cached tensor directly instead of a redundant second
    # dict lookup.
    return ptr_tensor
|
||||
|
||||
|
||||
# Fused MoE-LoRA grouped GEMM kernel. Grid axes: 0 = (SPLIT_K x M-tiles x
# N-tiles), 1 = weight slice, 2 = LoRA adapter slot. Computes
# C[slice] (+)= A[slice] @ B[slice, lora, expert] per tile, where B is an
# array of raw device pointers (one per stacked LoRA weight).
@triton.jit(
    do_not_specialize=[
        "num_valid_tokens",
        "EM",
        "stride_tl",
        "stride_el",
        "slice_a_size",
        "slice_c_size",
    ]
)
def _fused_moe_lora_kernel(
    a_ptr,
    b_ptr,  # uint64 table of per-slice weight base pointers (see _get_ptr)
    c_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    num_experts,
    lora_ids,
    adapter_enabled,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_bl,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_tl,
    stride_el,
    slice_a_size,
    slice_c_size,
    # Meta-parameters
    num_slice_a: tl.constexpr,
    num_slice_c: tl.constexpr,
    top_k: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    USE_GDC: tl.constexpr,
    launch_pdl: tl.constexpr,
    IS_PRIMARY: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)
    lora_id = tl.load(lora_ids + lora_idx)

    if lora_id == -1:
        # Early exit for the no-lora case.
        return
    moe_enabled = tl.load(adapter_enabled + lora_id)
    if moe_enabled == 0:
        # Early exit for the no moe lora case.
        return
    max_loras = tl.num_programs(axis=2)
    grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

    # calculate pid_m,pid_n
    # Axis 0 packs SPLIT_K fastest; the remainder is a GROUP_SIZE_M-grouped
    # M/N tiling (standard Triton grouped-ordering for L2 reuse).
    pid_sk = pid % SPLIT_K
    pid_m_n = pid // SPLIT_K
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid_m_n // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
    pid_n = (pid_m_n % num_pid_in_group) // group_size_m

    # Tiles past this adapter's padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    # get the expert_id to process curr shard
    ind = lora_id * stride_el + pid_m
    expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
    if expert_id == -1:
        return
    # get a_ptr,b_ptr,c_ptr
    # B pointers are stored as integers; cast to the output element type.
    cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
    cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
    cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    # Map tile rows to the adapter's sorted token ids; out-of-range loads
    # fall back to 0 and are masked out via token_mask below.
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    token_ind = stride_tl * lora_id + offs_token_id
    offs_token = tl.load(
        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
    )
    token_mask = offs_token < num_valid_tokens

    # get a_ptrs,b_ptrs
    # `// top_k` collapses the top_k copies of each token onto one A row.
    a_ptrs = cur_a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    b_ptrs = (
        cur_b_ptr
        + lora_id * stride_bl
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # accumulator
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, grid_k):
        k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
        # pre-fetch lora weight
        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        # GDC wait waits for ALL programs in the prior kernel to complete
        # before continuing.
        if USE_GDC and not IS_PRIMARY:
            tl.extra.cuda.gdc_wait()
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
            other=0.0,
        )
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each row by its router weight (masked rows scale by 0).
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    if USE_GDC and IS_PRIMARY:
        # GDC launch dependents hints the runtime system to launch dependent kernels.
        tl.extra.cuda.gdc_launch_dependents()
    accumulator = accumulator.to(c_ptr.dtype.element_ty)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

    if SPLIT_K == 1:
        tl.store(c_ptrs, accumulator, mask=c_mask)
    else:
        # Partial K results from each split are summed atomically.
        tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
|
||||
|
||||
|
||||
@torch.inference_mode()
def _fused_moe_lora_shrink(
    a_intermediate_cache1: torch.Tensor,
    # (num_slices, num_tokens, top_k_num, max_lora_rank)
    qcurr_hidden_states: torch.Tensor,  # (num_tokens, K,)
    lora_a_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
    topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
    sorted_token_ids: torch.Tensor,  # (max_loras, _)
    expert_ids: torch.Tensor,  # (max_loras, _ ,)
    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
    top_k_num: int,
    lora_ids: torch.Tensor,
    adapter_enabled: torch.Tensor,
    ## adding for kernel
    device: torch.device,
    N: int,
    M: int,
    EM: int,
    K: int,
    num_tokens: int,
    num_experts: int,
    num_slices: int,
    block_size_m: int,
    block_size_n: int,
    block_size_k: int,
    group_size_m: int,
    num_warps: int,
    num_stages: int,
    split_k: int,
    mul_routed_weight: bool = False,
) -> None:
    """Launch the shrink half of the fused MoE-LoRA GEMM.

    Projects `qcurr_hidden_states` through each stacked LoRA-A weight into
    `a_intermediate_cache1` (mutated in place). Router weights are NOT
    applied here (MUL_ROUTED_WEIGHT=False); `mul_routed_weight` only
    selects the `top_k` row-collapse factor passed to the kernel.
    """
    # All slices share the A strides of the first stacked weight —
    # assumes every entry of lora_a_stacked has identical layout.
    w1_lora_a_stacked = lora_a_stacked[0]
    use_gdc = supports_pdl(qcurr_hidden_states.device)
    shrink_config = {
        "BLOCK_SIZE_M": block_size_m,
        "BLOCK_SIZE_N": block_size_n,
        "BLOCK_SIZE_K": block_size_k,
        "GROUP_SIZE_M": group_size_m,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "SPLIT_K": split_k,
        "USE_GDC": use_gdc,
        "launch_pdl": use_gdc,  # triton kernel metadata
    }

    # Table of raw device pointers, one per LoRA-A slice (grouped-GEMM style).
    b_ptr = _get_ptr(lora_a_stacked, device)

    # Grid: (SPLIT_K x M-tiles x N-tiles, num_slices, max_loras).
    grid = lambda META: (
        split_k
        * triton.cdiv(EM, META["BLOCK_SIZE_M"])
        * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        len(lora_a_stacked),
        lora_a_stacked[0].shape[0],
    )
    _fused_moe_lora_kernel[grid](
        qcurr_hidden_states,
        b_ptr,
        a_intermediate_cache1,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N,
        K,
        EM,
        num_tokens,
        num_experts,
        lora_ids,
        adapter_enabled,
        qcurr_hidden_states.stride(0),
        qcurr_hidden_states.stride(1),
        w1_lora_a_stacked.stride(0),
        w1_lora_a_stacked.stride(1),
        # Note: strides (3) and (2) are deliberately swapped so the kernel
        # reads the (rank, K) weight as K-major (i.e. transposed access).
        w1_lora_a_stacked.stride(3),
        w1_lora_a_stacked.stride(2),
        a_intermediate_cache1.stride(2),
        a_intermediate_cache1.stride(3),
        sorted_token_ids.stride(0),
        expert_ids.stride(0),
        slice_a_size=qcurr_hidden_states.numel(),
        slice_c_size=a_intermediate_cache1.numel() // num_slices,
        # A single input is broadcast across all output slices.
        num_slice_a=1,
        num_slice_c=num_slices,
        top_k=1 if mul_routed_weight else top_k_num,
        MUL_ROUTED_WEIGHT=False,
        IS_PRIMARY=True,
        **shrink_config,
    )
|
||||
|
||||
|
||||
@torch.inference_mode()
def _fused_moe_lora_expand(
    output: torch.Tensor,  # (num_tokens, top_k_num, N*len(lora_a_stacked),)
    a_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, max_lora_rank)
    lora_b_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
    topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
    sorted_token_ids: torch.Tensor,  # (max_loras, _)
    expert_ids: torch.Tensor,  # (max_loras, _ ,)
    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
    top_k_num: int,
    lora_ids: torch.Tensor,
    adapter_enabled: torch.Tensor,
    ## adding for kernel
    device: torch.device,
    N: int,
    M: int,
    EM: int,
    K: int,
    num_tokens: int,
    num_experts: int,
    num_slices: int,
    max_lora_rank: int,
    w1_output_dim_size: int,
    block_size_m: int,
    block_size_n: int,
    block_size_k: int,
    group_size_m: int,
    num_warps: int,
    num_stages: int,
    split_k: int,
    mul_routed_weight: bool = False,
    offset: int = 0,
) -> None:
    """Launch the expand half of the fused MoE-LoRA GEMM.

    Multiplies the shrink results in `a_intermediate_cache1` by each
    stacked LoRA-B weight into a scratch buffer, then accumulates each
    slice's result into `output` (mutated in place) at column window
    ``[i * N + offset, (i + 1) * N + offset)``.
    """
    b_ptr = _get_ptr(lora_b_stacked, device)
    # Rebind the GEMM dims for this launch: K is the LoRA rank, N the
    # per-slice output width (shadows the incoming K/N arguments).
    K = max_lora_rank
    N = w1_output_dim_size

    w1_lora_b_stacked = lora_b_stacked[0]

    # Flatten (num_slices, M, top_k_num, rank) to 2D rows of length rank;
    # per-slice offsets are recovered in-kernel via slice_a_size.
    a_intermediate_cache1 = a_intermediate_cache1.view(
        -1, a_intermediate_cache1.shape[3]
    )

    # Scratch output; accumulated into `output` after the kernel.
    b_intermediate_cache1 = torch.zeros(
        (num_slices, M, top_k_num, w1_output_dim_size),
        dtype=output.dtype,
        device=device,
    )
    use_gdc = supports_pdl(a_intermediate_cache1.device)
    expand_config = {
        "BLOCK_SIZE_M": block_size_m,
        "BLOCK_SIZE_N": block_size_n,
        "BLOCK_SIZE_K": block_size_k,
        "GROUP_SIZE_M": group_size_m,
        "num_warps": num_warps,
        "num_stages": num_stages,
        "SPLIT_K": split_k,  # Set split_k = 1 for expand calls
        "USE_GDC": use_gdc,
        "launch_pdl": use_gdc,  # triton kernel metadata
    }

    # Grid: (M-tiles x N-tiles, num_slices, max_loras) — no SPLIT_K factor.
    grid = lambda META: (
        triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        len(lora_b_stacked),
        lora_b_stacked[0].shape[0],
    )
    _fused_moe_lora_kernel[grid](
        a_intermediate_cache1,
        b_ptr,
        b_intermediate_cache1,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        N,
        K,
        EM,
        num_tokens,
        num_experts,
        lora_ids,
        adapter_enabled,
        a_intermediate_cache1.stride(0),
        a_intermediate_cache1.stride(1),
        w1_lora_b_stacked.stride(0),
        w1_lora_b_stacked.stride(1),
        # Strides (3) and (2) swapped: the kernel accesses the weight
        # transposed, matching the shrink launch.
        w1_lora_b_stacked.stride(3),
        w1_lora_b_stacked.stride(2),
        b_intermediate_cache1.stride(2),
        b_intermediate_cache1.stride(3),
        sorted_token_ids.stride(0),
        expert_ids.stride(0),
        slice_a_size=a_intermediate_cache1.numel() // num_slices,
        slice_c_size=b_intermediate_cache1.numel() // num_slices,
        num_slice_a=num_slices,
        num_slice_c=num_slices,
        top_k=1,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        IS_PRIMARY=False,
        **expand_config,
    )
    # Scatter each slice's result into its column window of `output`.
    for i in range(num_slices):
        output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
|
||||
|
||||
|
||||
@torch.inference_mode()
def _fused_moe_lora(
    output: torch.Tensor,  # (num_tokens, top_k_num, N*len(lora_a_stacked),)
    qcurr_hidden_states: torch.Tensor,  # (num_tokens, K,)
    lora_a_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
    lora_b_stacked: list[
        torch.Tensor
    ],  # [(max_loras, num_experts, N, max_lora_rank,),...]
    topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
    sorted_token_ids: torch.Tensor,  # (max_loras, _)
    expert_ids: torch.Tensor,  # (max_loras, _ ,)
    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
    max_lora_rank: int,
    top_k_num: int,
    lora_ids: torch.Tensor,
    adapter_enabled: torch.Tensor,
    shrink_block_size_m: int,
    shrink_block_size_n: int,
    shrink_block_size_k: int,
    shrink_group_size_m: int,
    shrink_num_warps: int,
    shrink_num_stages: int,
    shrink_split_k: int,
    expand_block_size_m: int,
    expand_block_size_n: int,
    expand_block_size_k: int,
    expand_group_size_m: int,
    expand_num_warps: int,
    expand_num_stages: int,
    expand_split_k: int,
    mul_routed_weight: bool = False,
    fully_sharded: bool = False,
    offset: int = 0,
) -> None:
    """Apply the MoE-LoRA delta to `output` (mutated in place).

    Runs the shrink kernel (hidden states x LoRA-A into an intermediate
    rank-sized cache), optionally reduces/gathers that cache across
    tensor-parallel ranks when `fully_sharded`, then runs the expand
    kernel (intermediate x LoRA-B) and accumulates into `output`.
    """
    assert len(lora_a_stacked) == len(lora_b_stacked) > 0
    assert (
        sorted_token_ids.dim()
        == expert_ids.dim()
        == topk_weights.dim()
        == qcurr_hidden_states.dim()
        == 2
    )
    assert (
        sorted_token_ids.shape[0]
        == expert_ids.shape[0]
        == num_tokens_post_padded.shape[0]
    )
    assert output.shape[0] == topk_weights.shape[0]
    assert top_k_num == topk_weights.shape[1]
    device = qcurr_hidden_states.device
    num_slices = len(lora_a_stacked)
    w1_lora_b_stacked = lora_b_stacked[0]
    num_experts = lora_a_stacked[0].shape[1]
    # Shrink-stage GEMM dims: N is the LoRA rank, K the hidden size.
    N = max_lora_rank
    M = topk_weights.shape[0]
    EM = sorted_token_ids.shape[1]
    K = qcurr_hidden_states.shape[1]
    num_tokens = M * top_k_num
    w1_output_dim_size = w1_lora_b_stacked.shape[2]

    # Intermediate buffer written by shrink and consumed by expand.
    a_intermediate_cache1 = torch.zeros(
        (num_slices, M, top_k_num, max_lora_rank),
        dtype=output.dtype,
        device=device,
    )

    _fused_moe_lora_shrink(
        a_intermediate_cache1,
        qcurr_hidden_states,
        lora_a_stacked,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        top_k_num,
        lora_ids,
        adapter_enabled,
        ## adding for kernel
        device,
        N,
        M,
        EM,
        K,
        num_tokens,
        num_experts,
        num_slices,
        shrink_block_size_m,
        shrink_block_size_n,
        shrink_block_size_k,
        shrink_group_size_m,
        shrink_num_warps,
        shrink_num_stages,
        shrink_split_k,
        mul_routed_weight,
    )

    if fully_sharded:
        # Rank-sharded LoRA: when the local B weight already covers the
        # full rank, partial sums are all-reduced; otherwise the rank
        # shards are concatenated via all-gather.
        if max_lora_rank == w1_lora_b_stacked.shape[-1]:
            a_intermediate_cache1 = tensor_model_parallel_all_reduce(
                a_intermediate_cache1
            )
        else:
            a_intermediate_cache1 = tensor_model_parallel_all_gather(
                a_intermediate_cache1
            )

        # reset max_lora_rank to the full rank after allgather
        max_lora_rank = a_intermediate_cache1.shape[-1]

    _fused_moe_lora_expand(
        output,
        a_intermediate_cache1,
        lora_b_stacked,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        top_k_num,
        lora_ids,
        adapter_enabled,
        ## adding for kernel
        device,
        N,
        M,
        EM,
        K,
        num_tokens,
        num_experts,
        num_slices,
        max_lora_rank,
        w1_output_dim_size,
        expand_block_size_m,
        expand_block_size_n,
        expand_block_size_k,
        expand_group_size_m,
        expand_num_warps,
        expand_num_stages,
        expand_split_k,
        mul_routed_weight,
        offset,
    )
|
||||
|
||||
|
||||
def _fused_moe_lora_fake(
    output: torch.Tensor,
    qcurr_hidden_states: torch.Tensor,
    lora_a_stacked: list[torch.Tensor],
    lora_b_stacked: list[torch.Tensor],
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    max_lora_rank: int,
    top_k_num: int,
    lora_ids: torch.Tensor,
    adapter_enabled: torch.Tensor,
    shrink_block_size_m: int,
    shrink_block_size_n: int,
    shrink_block_size_k: int,
    shrink_group_size_m: int,
    shrink_num_warps: int,
    shrink_num_stages: int,
    shrink_split_k: int,
    expand_block_size_m: int,
    expand_block_size_n: int,
    expand_block_size_k: int,
    expand_group_size_m: int,
    expand_num_warps: int,
    expand_num_stages: int,
    expand_split_k: int,
    mul_routed_weight: bool = False,
    # Added to mirror `_fused_moe_lora`: a fake impl whose signature
    # diverges from the registered op breaks torch custom-op tracing.
    fully_sharded: bool = False,
    offset: int = 0,
) -> None:
    """No-op fake (meta) implementation of `_fused_moe_lora`.

    Used as the `fake_impl` for custom-op registration; the real op only
    mutates `output`, so there is nothing to compute here.
    """
    return
|
||||
|
||||
|
||||
def _fused_moe_lora_shrink_fake(
    a_intermediate_cache1: torch.Tensor,
    qcurr_hidden_states: torch.Tensor,
    lora_a_stacked: list[torch.Tensor],
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    top_k_num: int,
    lora_ids: torch.Tensor,
    adapter_enabled: torch.Tensor,
    device: torch.device,
    N: int,
    M: int,
    EM: int,
    K: int,
    num_tokens: int,
    num_experts: int,
    num_slices: int,
    block_size_m: int,
    block_size_n: int,
    block_size_k: int,
    group_size_m: int,
    num_warps: int,
    num_stages: int,
    split_k: int,
    mul_routed_weight: bool = False,
) -> None:
    """No-op fake (meta) implementation of `_fused_moe_lora_shrink`.

    Used as the `fake_impl` for custom-op registration; the real op only
    mutates `a_intermediate_cache1`, so there is nothing to compute here.
    """
    return
|
||||
|
||||
|
||||
def _fused_moe_lora_expand_fake(
    output: torch.Tensor,
    a_intermediate_cache1: torch.Tensor,
    lora_b_stacked: list[torch.Tensor],
    topk_weights: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    top_k_num: int,
    lora_ids: torch.Tensor,
    adapter_enabled: torch.Tensor,
    device: torch.device,
    N: int,
    M: int,
    EM: int,
    K: int,
    num_tokens: int,
    num_experts: int,
    num_slices: int,
    max_lora_rank: int,
    w1_output_dim_size: int,
    block_size_m: int,
    block_size_n: int,
    block_size_k: int,
    group_size_m: int,
    num_warps: int,
    num_stages: int,
    split_k: int,
    mul_routed_weight: bool = False,
    # Added to mirror `_fused_moe_lora_expand`: a fake impl whose
    # signature diverges from the registered op breaks tracing.
    offset: int = 0,
) -> None:
    """No-op fake (meta) implementation of `_fused_moe_lora_expand`.

    Used as the `fake_impl` for custom-op registration; the real op only
    mutates `output`, so there is nothing to compute here.
    """
    return
|
||||
|
||||
|
||||
# Register the MoE-LoRA entry points as torch custom ops so they compose
# with torch.compile; fall back to the raw Python functions when the
# custom-op machinery is unavailable (registration raising AttributeError).
try:
    direct_register_custom_op(
        op_name="fused_moe_lora",
        op_func=_fused_moe_lora,
        mutates_args=["output"],
        fake_impl=_fused_moe_lora_fake,
    )

    direct_register_custom_op(
        op_name="fused_moe_lora_shrink",
        op_func=_fused_moe_lora_shrink,
        mutates_args=["a_intermediate_cache1"],
        fake_impl=_fused_moe_lora_shrink_fake,
    )

    direct_register_custom_op(
        op_name="fused_moe_lora_expand",
        op_func=_fused_moe_lora_expand,
        mutates_args=["output"],
        fake_impl=_fused_moe_lora_expand_fake,
    )

    # Public entry points resolve through the dispatcher.
    fused_moe_lora = torch.ops.vllm.fused_moe_lora
    fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink
    fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand

except AttributeError:
    # Registration unavailable: expose the plain Python implementations.
    fused_moe_lora = _fused_moe_lora
    fused_moe_lora_shrink = _fused_moe_lora_shrink
    fused_moe_lora_expand = _fused_moe_lora_expand
|
||||
340
vllm/lora/ops/triton_ops/kernel_utils.py
Normal file
340
vllm/lora/ops/triton_ops/kernel_utils.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Utilities for Punica kernel construction.
|
||||
"""
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
def mm_k(
    a_ptr,
    b_ptr,
    ak_stride,
    bk_stride,
    offset_k,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    b_dtype: tl.constexpr,
    USE_GDC: tl.constexpr,
    base_k,
):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate, through the K dimension to compute the partial/complete
    matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.
    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        offset_k: Per-lane K offsets (unused here; kept for signature parity
            with callers)
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
            masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        CAST_TYPE: if True, cast the values from the A matrix to the B
            matrix dtype.
        b_dtype: datatype of the B matrix
        USE_GDC: Whether to use PDL. True indicates use.
        base_k: Base offset along K dimension for current SPLIT_K group
    """
    # Accumulate in fp32 regardless of input dtype for numerical stability.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Step size along K for each iteration
    STEP_K = BLOCK_K * SPLIT_K

    # Total number of iterations (compile-time constant)
    num_iters = tl.cdiv(K, STEP_K)

    for k in range(num_iters):
        # Current iteration's global K offset
        iter_k = k * STEP_K + base_k

        # Check if this iteration is completely valid (no masking needed).
        # NOTE: block_end is only consulted in the non-EVEN_K path below.
        block_end = iter_k + BLOCK_K

        if EVEN_K:
            # K is divisible by BLOCK_K, no masking ever needed
            # pre-fetch lora weight
            tiled_b = tl.load(b_ptr)
            if USE_GDC:
                # PDL: wait for the prior dependent kernel before reading A.
                tl.extra.cuda.gdc_wait()
            tiled_a = tl.load(a_ptr)
            if CAST_TYPE:
                tiled_a = tiled_a.to(b_dtype)
            accumulator += tl.dot(tiled_a, tiled_b)
        else:
            # Check if we need element-wise masking
            if iter_k >= K:
                # Entire block out of range, skip
                pass
            elif block_end <= K:
                # Entire block in range, no masking needed (fast path)
                tiled_b = tl.load(b_ptr)
                if USE_GDC:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(a_ptr)
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                accumulator += tl.dot(tiled_a, tiled_b)
            else:
                # Partial block, need masking (only last iteration)
                k_offsets = tl.arange(0, BLOCK_K)
                mask = iter_k + k_offsets < K
                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
                if USE_GDC:
                    tl.extra.cuda.gdc_wait()
                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
                if CAST_TYPE:
                    tiled_a = tiled_a.to(b_dtype)
                accumulator += tl.dot(tiled_a, tiled_b)

        # Advance both operand pointers to this CTA's next K chunk.
        a_ptr += STEP_K * ak_stride
        b_ptr += STEP_K * bk_stride

    return accumulator
||||
|
||||
|
||||
@triton.jit
def do_expand_kernel(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    USE_GDC: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice,
    compute the matrix product and store in the appropriate output location.
    Given that this is an expand kernel, we don't perform any split-K reduction
    as the K dimension is assumed to be small.
    """

    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:
        # integer: all slices share the same lora strides
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        # pointer: per-slice strides, indexed by slice_id
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)

    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        # lora_ptr holds per-slice weight base addresses; reinterpret the
        # loaded integer as a pointer of the output's element type.
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(out_ptr.dtype.element_ty)
        )

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Hint contiguity/alignment to the compiler for vectorized loads.
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    a_ptr = (
        cur_input_ptr
        + ram[:, None] * input_d1_stride
        + offset_k[None, :] * input_d2_stride
    )
    b_ptr = (
        cur_lora_ptr
        + cur_lora_d0_stride * lora_index
        + offset_k[:, None] * cur_lora_d2_stride
        + rbn[None, :] * cur_lora_d1_stride
    )

    # Compute the block matrix product. No split-K here (see docstring).
    SPLIT_K = 1

    accumulator = mm_k(
        a_ptr,
        b_ptr,
        input_d2_stride,
        cur_lora_d2_stride,
        offset_k,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        CAST_TYPE,
        cur_lora_ptr.dtype.element_ty,
        USE_GDC,
        base_k=0,
    )

    # Cast the fp32 accumulator back to the lora weight dtype for store.
    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (
        out_ptr
        + ram[:, None] * output_d0_stride
        + offset_cn[None, :] * output_d1_stride
    )
    # Mask out rows past M_LEN and columns past this slice's extent.
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N))

    if ADD_INPUTS:
        # Accumulate into the existing output instead of overwriting.
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
||||
|
||||
|
||||
@triton.jit
def do_shrink_kernel(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    USE_GDC: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice, compute the
    matrix product and store in the appropriate output location.

    Unlike the expand kernel, K (the hidden size) may be large, so the
    product may be split across SPLIT_K thread blocks; partial results are
    combined via atomic adds at the bottom of this function.
    """

    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        # current lora ptr
        cur_lora_ptr = lora_ptr
    else:
        # current lora ptr: per-slice base address reinterpreted as a
        # pointer of the input's element type.
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(input_ptr.dtype.element_ty)
        )

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    # Hint contiguity/alignment to the compiler for vectorized loads.
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # Identify A and B block pointers; each split-K group starts at its own
    # K offset (pid_sk * BLOCK_K).
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (
        input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride
    )
    b_ptr = (
        cur_lora_ptr
        + lora_d0_stride * lora_index
        + rbn[None, :] * lora_d1_stride
        + offset_k[:, None] * lora_d2_stride
    )

    # Compute partial/complete block matrix product.
    accumulator = mm_k(
        a_ptr,
        b_ptr,
        input_d1_stride,
        lora_d2_stride,
        offset_k,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        False,  # CAST_TYPE: input and lora dtypes match in shrink
        cur_lora_ptr.dtype.element_ty,
        False,  # USE_GDC is always False in shrink kernel
        base_k=pid_sk * BLOCK_K,
    )
    # GDC launch dependents hints the runtime system to launch dependent kernels.
    if USE_GDC:
        tl.extra.cuda.gdc_launch_dependents()
    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride
    c_ptr = (
        cur_out_ptr
        + ram[:, None] * output_d1_stride
        + offset_cn[None, :] * output_d2_stride
    )
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
    # Apply the LoRA scaling factor before write-back.
    accumulator *= scaling

    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        # Partial products from each split-K group are summed atomically.
        tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed")
||||
310
vllm/lora/ops/triton_ops/lora_expand_op.py
Normal file
310
vllm/lora/ops/triton_ops/lora_expand_op.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .utils import supports_pdl
|
||||
|
||||
|
||||
@triton.jit
def _lora_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    slice_start_loc,
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,  # 1
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,  # 1
    output_d0_stride,
    output_d1_stride,  # 1
    output_hs_ptr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    USE_GDC: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """Top-level expand kernel: maps the launch grid (m*n CTAs, slices,
    loras) onto per-CTA work, does per-CTA early exits, gathers the token
    row indices this CTA owns, then delegates the actual block matmul and
    store to do_expand_kernel."""
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # axis 0 linearizes the (m, n) CTA grid; decompose it back.
    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn % cta_m_num
    pid_n = (pid_mn // cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # When the output dimensions of each slice are the same, cur_n=N,
    # otherwise cur_n=tl.load(output_hs_ptr + slice_id); this situation
    # exists in GQA's qkv linear.
    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
    if pid_n * BLOCK_N >= curr_N:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (
        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
    )

    # Load all relevant row indices (modulo keeps out-of-range lanes in
    # bounds; the store mask in do_expand_kernel discards their results).
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_expand_kernel(
        pid_n,
        lora_id,
        slice_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        curr_N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        slice_start_loc,
        # input ptr strides
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,
        # lora ptr strides
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,
        # out ptr strides
        output_d0_stride,
        output_d1_stride,
        # constants
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SAME_STRIDE,
        SLICE_NUM,
        EVEN_K,
        CAST_TYPE,
        ADD_INPUTS,
        USE_GDC,
    )
||||
|
||||
|
||||
@torch.inference_mode()
def _lora_expand(
    inputs: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    lora_b_weights: list[torch.Tensor],  # shape [num_lora, hidden_size, lora_rank]
    output_tensor: torch.Tensor,  # shape [num_tokens, hidden_size * num_slices]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    no_lora_flag_cpu: torch.Tensor,  # shape [1]
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    """
    Host-side launcher for the lora_expand Triton kernel. Mutates
    output_tensor in place; returns None.

    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (list[torch.Tensor]): lora'b weight
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates that
            LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
            the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
            of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i]
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
            if there are any requests that require LoRA.
        offset_start (int, optional): Offset start for output_tensor.
            Defaults to 0.
        add_inputs (bool, optional): Whether to add the input tensor to the
            output tensor. Defaults to False.
    """

    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA.
        return

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    for weight in lora_b_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(0) == len(lora_b_weights)
    assert output_tensor.is_contiguous()

    # metadata sanity check.
    M = inputs.size(1)
    assert token_lora_mapping.size(0) == M
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    # Flatten the per-slice lora_b tensors into device-side pointer/stride
    # tables consumable by the kernel.
    (
        slice_start_tensor,
        lora_ptr_tensor,
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        hidden_sizes_tensor,
        same_stride,
        MAX_N,
    ) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)

    K = lora_b_weights[0].shape[-1]  # K= rank
    ADD_INPUTS = add_inputs
    MAX_LORAS = lora_ids.size(0)
    CAST_TYPE = False
    NUM_SLICES = len(lora_b_weights)

    # Triton kernel configs.
    kernel_config = get_lora_op_configs(
        op_type="expand",
        max_loras=MAX_LORAS,
        batch=M,
        hidden_size=MAX_N,
        rank=K,
        num_slices=NUM_SLICES,
        add_inputs=add_inputs,
    )
    BLOCK_M = kernel_config["block_m"]
    BLOCK_N = kernel_config["block_n"]
    BLOCK_K = kernel_config["block_k"]
    NUM_WARPS = kernel_config["num_warps"]
    NUM_CTAS = kernel_config["num_ctas"]
    NUM_STAGES = kernel_config["num_stages"]

    EVEN_K = K % BLOCK_K == 0  # type: ignore

    # fp32 activations with fp16/bf16 lora weights require an in-kernel cast.
    if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
        torch.float16,
        torch.bfloat16,
    ]:
        CAST_TYPE = True

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only a few input tokens require
    # LoRA. This might not be the best in all cases.
    grid = (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks simply exit.
        MAX_LORAS,
    )
    use_gdc = supports_pdl(inputs.device)
    # NOTE: positional argument order must match _lora_expand_kernel exactly.
    _lora_expand_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        MAX_N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        slice_start_tensor,
        inputs.stride(0),
        inputs.stride(1),
        inputs.stride(2),
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        hidden_sizes_tensor,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
        NUM_SLICES,
        same_stride,
        use_gdc,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
        launch_pdl=use_gdc,
    )

    return
||||
|
||||
|
||||
def _lora_expand_fake(
    inputs: torch.Tensor,
    lora_b_weights: list[torch.Tensor],
    output_tensor: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    token_indices_sorted_by_lora_ids: torch.Tensor,
    num_tokens_per_lora: torch.Tensor,
    lora_token_start_loc: torch.Tensor,
    lora_ids: torch.Tensor,
    no_lora_flag_cpu: torch.Tensor,
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    """No-op fake (meta) implementation of ``lora_expand``.

    Registered as ``fake_impl`` for the custom op below so tracing and
    shape inference can treat the op as an in-place mutation of
    ``output_tensor`` without executing the Triton kernel.
    """
    return None
||||
|
||||
|
||||
# Register lora_expand as a torch custom op (visible as
# torch.ops.vllm.lora_expand). On failure (AttributeError during op
# namespace lookup), fall back to the raw Python implementation.
try:
    direct_register_custom_op(
        op_name="lora_expand",
        op_func=_lora_expand,
        # The op writes its result into output_tensor in place.
        mutates_args=["output_tensor"],
        fake_impl=_lora_expand_fake,
    )
    lora_expand = torch.ops.vllm.lora_expand

except AttributeError:
    lora_expand = _lora_expand
||||
154
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
Normal file
154
vllm/lora/ops/triton_ops/lora_kernel_metadata.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
LoRA kernels metadata preparation utilities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAKernelMeta:
|
||||
token_lora_mapping: torch.Tensor
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor
|
||||
active_lora_ids: torch.Tensor
|
||||
num_tokens_per_lora: torch.Tensor
|
||||
lora_token_start_loc: torch.Tensor
|
||||
|
||||
# The V1 architecture uses the traced torch.compile graphs to execute
|
||||
# a forward pass. Things to note about this process,
|
||||
# 1. The tracing infers all python scalar datatype objects into a constant
|
||||
# value.
|
||||
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
|
||||
# is an experimental feature in pytorch)
|
||||
# 3. The internals of torch.ops functions are not traced.
|
||||
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
|
||||
# to early exit from inside the lora_expand / lora_shrink torch operation.
|
||||
no_lora_flag_cpu: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
max_loras: int, max_num_tokens: int, device: torch.device | str
|
||||
) -> "LoRAKernelMeta":
|
||||
token_lora_mapping = torch.empty(
|
||||
max_num_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
token_indices_sorted_by_lora_ids = torch.empty(
|
||||
max_num_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# +1 because "no-lora" is also a possibility
|
||||
# example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1]
|
||||
# is a possibility.
|
||||
active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device)
|
||||
|
||||
# using running example, [3, 10, 5, 2] is a possibility.
|
||||
num_tokens_per_lora = torch.zeros(
|
||||
max_loras + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# +2 for this because, the first index is always 0.
|
||||
# using running example, lora_token_start_loc
|
||||
# is [0, 3, 13, 18, 20].
|
||||
lora_token_start_loc = torch.zeros(
|
||||
max_loras + 2, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
|
||||
|
||||
return LoRAKernelMeta(
|
||||
token_lora_mapping=token_lora_mapping,
|
||||
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
|
||||
active_lora_ids=active_lora_ids,
|
||||
num_tokens_per_lora=num_tokens_per_lora,
|
||||
lora_token_start_loc=lora_token_start_loc,
|
||||
no_lora_flag_cpu=no_lora_flag_cpu,
|
||||
)
|
||||
|
||||
def _reset(self):
|
||||
self.active_lora_ids.fill_(-1)
|
||||
self.num_tokens_per_lora.fill_(0)
|
||||
self.lora_token_start_loc.fill_(0)
|
||||
self.no_lora_flag_cpu.fill_(False)
|
||||
|
||||
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
|
||||
"""
|
||||
Prepare kernel metadata tensors for the current forward pass.
|
||||
|
||||
Args:
|
||||
token_lora_mapping (torch.Tensor): Tensor containing lora indices
|
||||
for each input token.
|
||||
"""
|
||||
|
||||
self._reset()
|
||||
|
||||
# Check and record no-lora case.
|
||||
no_lora = torch.all(token_lora_mapping == -1)
|
||||
self.no_lora_flag_cpu[0] = no_lora
|
||||
|
||||
if no_lora:
|
||||
# Early exit. LoRA kernels will not be run.
|
||||
return
|
||||
|
||||
num_tokens = token_lora_mapping.size(0)
|
||||
|
||||
# copy token lora mapping
|
||||
self.token_lora_mapping[:num_tokens].copy_(
|
||||
token_lora_mapping, non_blocking=True
|
||||
)
|
||||
|
||||
# token_indices_sorted_by_lora_ids
|
||||
_, token_indices_sorted_by_lora_ids = torch.sort(
|
||||
token_lora_mapping, stable=True
|
||||
)
|
||||
# start gpu transfer
|
||||
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
|
||||
token_indices_sorted_by_lora_ids, non_blocking=True
|
||||
)
|
||||
|
||||
# active_lora_ids, num_tokens_per_lora
|
||||
lora_ids, num_tokens_per_lora = torch.unique(
|
||||
token_lora_mapping, sorted=True, return_counts=True
|
||||
)
|
||||
self.active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True)
|
||||
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
|
||||
num_tokens_per_lora, non_blocking=True
|
||||
)
|
||||
|
||||
# lora_token_start_loc
|
||||
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
|
||||
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
|
||||
lora_token_start_loc, non_blocking=True
|
||||
)
|
||||
|
||||
def meta_args(
|
||||
self, token_nums: int
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
This function returns the kernel metadata required for the current
|
||||
forward pass execution of the kernel. The function returns all the
|
||||
metadata required by the kernel, in order, as a tuple, so it can be
|
||||
unpacked directly during the lora_shrink/lora_expand function call.
|
||||
|
||||
Args:
|
||||
token_nums (int): Number of input tokens in the current forward
|
||||
pass of the kernel.
|
||||
"""
|
||||
return (
|
||||
self.token_lora_mapping[:token_nums],
|
||||
self.token_indices_sorted_by_lora_ids[:token_nums],
|
||||
self.num_tokens_per_lora,
|
||||
self.lora_token_start_loc,
|
||||
self.active_lora_ids,
|
||||
self.no_lora_flag_cpu,
|
||||
)
|
||||
287
vllm/lora/ops/triton_ops/lora_shrink_op.py
Normal file
287
vllm/lora/ops/triton_ops/lora_shrink_op.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
|
||||
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .utils import supports_pdl
|
||||
|
||||
|
||||
@triton.jit
def _lora_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    scaling,
    input_d0_stride,
    input_d1_stride,
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    USE_GDC: tl.constexpr,
    launch_pdl: tl.constexpr,
):
    """Top-level shrink kernel: decomposes the linearized
    (split_k, m, n) grid on axis 0 with grouped (swizzled) ordering,
    does per-CTA early exits, gathers the token row indices this CTA
    owns, then delegates the block matmul to do_shrink_kernel."""
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # axis 0 linearizes (split_k, m, n); peel off the split-K id first.
    pid_sk_m_n = tl.program_id(axis=0)
    pid_sk = pid_sk_m_n % SPLIT_K

    pid_m_n = pid_sk_m_n // SPLIT_K
    num_pid_in_group = GROUP_SIZE_M * cta_n_num
    group_id = pid_m_n // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)

    # Column-major ordering within groups for better cache reuse
    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
    pid_n = (pid_m_n % num_pid_in_group) // group_size_m

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (
        token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
    )
    # Load all relevant row indices (modulo keeps out-of-range lanes in
    # bounds; the store mask in do_shrink_kernel discards their results).
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_shrink_kernel(
        pid_n,
        pid_sk,
        slice_id,
        lora_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM,
        USE_GDC,
    )
||||
|
||||
|
||||
@torch.inference_mode()
def _lora_shrink(
    inputs: torch.Tensor,  # shape [num_tokens, hidden_size]
    lora_a_weights: list[torch.Tensor],  # shape [num_loras, lora_rank, hidden_size]
    output_tensor: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    no_lora_flag_cpu: torch.Tensor,  # shape [1]
    scaling: float,
) -> None:
    """
    Launch the Triton LoRA "shrink" kernel (x @ A^T) over all slices/LoRAs.

    Args:
        inputs (torch.Tensor): Input tensor
        lora_a_weights (list[torch.Tensor]): LoRA weights
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates that
            LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
            the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
            of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i]
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
            if there are any requests that require LoRA.
        scaling (float): Scaling factor.
    """

    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA.
        return

    # The kernel only supports half-precision activations/weights.
    assert inputs.dtype == lora_a_weights[0].dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    for weight in lora_a_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(1) == lora_a_weights[0].size(-1)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    # metadata sanity check
    M = inputs.size(0)
    assert token_lora_mapping.size(0) == M
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    # NOTE(review): output must start zeroed — presumably the kernel
    # accumulates partial results into it (SPLIT_K partitions the K dim);
    # confirm in _lora_shrink_kernel.
    output_tensor.zero_()

    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
        _get_lora_a_ptr(lora_a_weights, inputs.device)
    )
    N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
    NUM_SLICES = len(lora_a_weights)
    MAX_LORAS = lora_ids.size(0)

    # Triton kernel configs
    kernel_config = get_lora_op_configs(
        "shrink",
        max_loras=MAX_LORAS,
        batch=M,
        hidden_size=K,
        rank=N,
        num_slices=NUM_SLICES,
    )
    BLOCK_M = kernel_config["block_m"]
    BLOCK_N = kernel_config["block_n"]
    BLOCK_K = kernel_config["block_k"]
    SPLIT_K = kernel_config["split_k"]
    NUM_WARPS = kernel_config["num_warps"]
    NUM_STAGES = kernel_config["num_stages"]
    NUM_CTAS = kernel_config["num_ctas"]
    GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
    # True when K divides evenly into BLOCK_K * SPLIT_K chunks, letting the
    # kernel skip boundary masking along K.
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only few of the input tokens
    # require LoRA. This might not be the best in all cases.
    grid = (
        SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks exit early.
        MAX_LORAS,
    )
    # Programmatic Dependent Launch (SM90+) overlaps dependent kernel launches.
    use_gdc = supports_pdl(inputs.device)
    _lora_shrink_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        scaling,
        # input strides
        inputs.stride(0),
        inputs.stride(1),
        # lora strides
        lora_strides_d0,
        lora_strides_d1,
        lora_strides_d2,
        # output strides
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor.stride(2),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        GROUP_SIZE_M,
        NUM_SLICES,
        use_gdc,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
        launch_pdl=use_gdc,
    )

    return
|
||||
|
||||
|
||||
def _lora_shrink_fake(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: list[torch.Tensor],
|
||||
output_tensor: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
token_indices_sorted_by_lora_ids: torch.Tensor,
|
||||
num_tokens_per_lora: torch.Tensor,
|
||||
lora_token_start_loc: torch.Tensor,
|
||||
lora_ids: torch.Tensor,
|
||||
no_lora_flag_cpu: torch.Tensor,
|
||||
scaling: float,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
# Register `_lora_shrink` as the custom op `torch.ops.vllm.lora_shrink`.
# If registration is unavailable (AttributeError from the helper/op lookup),
# fall back to calling the Python implementation directly.
try:
    direct_register_custom_op(
        op_name="lora_shrink",
        op_func=_lora_shrink,
        # `output_tensor` is written in place by the op.
        mutates_args=["output_tensor"],
        fake_impl=_lora_shrink_fake,
    )
    lora_shrink = torch.ops.vllm.lora_shrink

except AttributeError:
    lora_shrink = _lora_shrink
|
||||
295
vllm/lora/ops/triton_ops/utils.py
Normal file
295
vllm/lora/ops/triton_ops/utils.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Caches keyed by the data_ptr()s of the LoRA weight tensors. Populated once
# (during profile_run) and used as constant lookup tables afterwards.
# Fix: annotate with the type `torch.Tensor`, not the factory function
# `torch.tensor`.
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
|
||||
|
||||
|
||||
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
    """
    Collect pointer/stride metadata for the LoRA-A weight tensors.

    `_LORA_A_PTR_DICT` collects the required information during `profile_run`,
    after which it remains constant and subsequent usage is through the LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(w.data_ptr() for w in lora_a_weights)

    cached = _LORA_A_PTR_DICT.get(key)
    if cached:
        return cached

    ptrs: list[int] = []
    strides_d0: list[int] = []
    strides_d1: list[int] = []
    strides_d2: list[int] = []
    for weight in lora_a_weights:
        if weight.ndim == 4:  # shape:(lora_num,1,size,rank)
            assert weight.size(1) == 1
            weight = weight.squeeze(dim=1)
        else:
            assert weight.ndim == 3  # shape:(lora_num,size,rank)
        assert weight.is_contiguous()
        ptrs.append(weight.data_ptr())
        strides_d0.append(weight.stride(0))
        strides_d1.append(weight.stride(1))
        strides_d2.append(weight.stride(2))

    # The kernel receives scalar strides, so every slice must share one
    # stride layout.
    if any(len(set(s)) > 1 for s in (strides_d0, strides_d1, strides_d2)):
        raise ValueError("All LoRA weights must have the same stride.")

    if len(lora_a_weights) > 1:
        ptr_tensor = torch.tensor(ptrs, device=device, dtype=torch.uint64)
    else:
        # Single slice: pass the weight tensor itself instead of a pointer
        # array.
        ptr_tensor = lora_a_weights[0]

    entry = (ptr_tensor, strides_d0[0], strides_d1[0], strides_d2[0])
    _LORA_A_PTR_DICT[key] = entry
    return entry
|
||||
|
||||
|
||||
def _get_lora_b_ptr(
    lora_weights: list[torch.Tensor], offset_start: int, device: torch.device
):
    """
    Collect pointer/stride/slice metadata for the LoRA-B weight tensors.

    `_LORA_B_PTR_DICT` collects the required information during `profile_run`,
    after which it remains constant and subsequent usage is through the LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py

    Args:
        lora_weights: one LoRA-B weight tensor per slice, each of shape
            (lora_num, size, rank) or (lora_num, 1, size, rank).
        offset_start: column offset of the first slice in the fused output.
        device: device on which to materialize metadata tensors.
    """

    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
    if values := _LORA_B_PTR_DICT.get(key):
        return values
    slice_offset_lst = []
    tensor_ptrs = []
    lora_strides_d0 = []
    lora_strides_d1 = []
    lora_strides_d2 = []
    hidden_sizes = []
    slice_offset = offset_start
    for lora_b_weight in lora_weights:
        if lora_b_weight.ndim == 4:  # shape:(lora_num,1,size,rank)
            assert lora_b_weight.size(1) == 1
            lora_b_weight = lora_b_weight.squeeze(dim=1)
        else:
            assert lora_b_weight.ndim == 3  # shape:(lora_num,size,rank)
        assert lora_b_weight.is_contiguous()
        tensor_ptrs.append(lora_b_weight.data_ptr())
        lora_strides_d0.append(lora_b_weight.stride(0))
        lora_strides_d1.append(lora_b_weight.stride(1))
        lora_strides_d2.append(lora_b_weight.stride(2))
        slice_offset_lst.append(slice_offset)
        slice_offset += lora_b_weight.size(1)
        hidden_sizes.append(lora_b_weight.size(1))

    if len(lora_weights) > 1:
        # note these are device tensors
        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
        slice_start_tensor = torch.tensor(
            slice_offset_lst, device=device, dtype=torch.uint64
        )
    else:
        slice_start_tensor = slice_offset_lst[0]
        # Bug fix: index the weights *list* (mirroring `_get_lora_a_ptr`),
        # not the loop variable `lora_b_weight` — `lora_b_weight[0]` is a
        # 2-D slice of the (possibly squeezed) last weight, not the weight
        # tensor itself.
        lora_ptr_tensor = lora_weights[0]

    # If each lora has the same stride, there's no need to use a
    # tensor for storage.
    if (
        len(set(lora_strides_d0)) == 1
        and len(set(lora_strides_d1)) == 1
        and len(set(lora_strides_d2)) == 1
    ) and len(set(hidden_sizes)) == 1:
        lora_strides_d0_tensor = lora_strides_d0[0]
        lora_strides_d1_tensor = lora_strides_d1[0]
        lora_strides_d2_tensor = lora_strides_d2[0]
        hidden_sizes_tensor = hidden_sizes[0]
        same_stride = True

    else:
        lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
        lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
        lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
        hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
        same_stride = False
    # MAX_N is the maximum hidden size among all the lora_b weights
    MAX_N = max(hidden_sizes)
    _LORA_B_PTR_DICT[key] = (
        slice_start_tensor,
        lora_ptr_tensor,
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        hidden_sizes_tensor,
        same_stride,
        MAX_N,
    )
    return _LORA_B_PTR_DICT.get(key)
|
||||
|
||||
|
||||
@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    """
    Load a user-tuned LoRA kernel config JSON for `op_type`, if available.

    Looks under `VLLM_TUNED_CONFIG_FOLDER` for a file named after the current
    GPU and the op type (the "expand" op additionally encodes `add_inputs`).

    Returns:
        The parsed JSON dict, or None when no tuned config is available.
    """
    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    if user_defined_config_folder is not None:
        # Normalize the GPU name into a filesystem-friendly token.
        gpu_name = torch.cuda.get_device_name()
        gpu_name = gpu_name.replace(" ", "_")
        gpu_name = gpu_name.replace("-", "_")

        config_fname = None
        # only expand op needs to consider add_inputs
        if op_type == "expand":
            config_fname = (
                f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
            )
        else:
            config_fname = f"{gpu_name}_{op_type.upper()}.json"

        config_path = Path(f"{user_defined_config_folder}/{config_fname}")
        if not config_path.exists():
            # Fix: message typo "founded" -> "found".
            logger.warning_once(f"No LoRA kernel configs found in {config_path}")
            return None

        # Load json
        logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
        with open(str(config_path)) as f:
            config_data = json.load(f)
    else:
        config_data = None

    return config_data
|
||||
|
||||
|
||||
def _closest_config(config_data: dict, target: int):
    """Return config_data[str(target)] when present (and truthy), otherwise
    the entry whose integer key is numerically closest to `target`."""
    return (
        config_data.get(str(target))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - target))]
    )


@functools.lru_cache
def get_lora_op_configs(
    op_type: str,
    max_loras: int,
    batch: int,
    hidden_size: int,
    rank: int,
    num_slices: int,
    add_inputs: bool | None = None,
    moe_intermediate_size: int | None = None,
) -> dict[str, int | None]:
    """
    Resolve Triton launch parameters (block sizes, warps, stages, ...) for a
    LoRA op.

    Prefers user-tuned configs from `load_lora_op_config`, falling back to
    hard-coded defaults. Tuned configs are indexed as
    config_data[max_loras][num_slices][m][k][n] (plus an optional
    moe-intermediate-size level), with nearest-key fallback at every level
    except num_slices.
    """
    # Add support for fused_moe_lora ops
    assert op_type in [
        "shrink",
        "expand",
        "fused_moe_lora_w13_shrink",
        "fused_moe_lora_w13_expand",
        "fused_moe_lora_w2_shrink",
        "fused_moe_lora_w2_expand",
    ]

    # default config
    if op_type == "shrink":
        default = {
            "block_m": 32,
            "block_n": 16,
            # Small batches benefit from deeper K-splitting.
            "block_k": 256 if batch < 128 else 32,
            "split_k": 64 if batch < 128 else 8,
            "num_warps": 4,
            "num_ctas": 1,
            "group_size_m": 8,
            "num_stages": 2,
            "max_nreg": None,
        }
    # The default config for fused_moe_lora ops
    elif op_type in [
        "fused_moe_lora_w13_shrink",
        "fused_moe_lora_w13_expand",
        "fused_moe_lora_w2_shrink",
        "fused_moe_lora_w2_expand",
    ]:
        default = {
            "block_m": 64,
            "block_n": 64,
            "block_k": 32,
            "num_warps": 4,
            "num_stages": 3,
            "group_size_m": 8,
            "split_k": 1,
        }
    else:
        default = {
            "block_m": 64,
            "block_n": 128,
            "block_k": 16,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }
    m = batch

    # The GEMM reduction/output dims swap between shrink and expand.
    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)

    config_data: Any
    config_data = load_lora_op_config(op_type, add_inputs)
    if not config_data:
        logger.warning_once("Using default LoRA kernel configs")
        return default

    # config is structured as config_data[max_loras][num_slices][m][k][n] = {}
    config_data = _closest_config(config_data, max_loras)
    # slice by num_slices (exact match required)
    config_data = config_data[str(num_slices)]
    # slice by m, k, n with nearest-key fallback
    config_data = _closest_config(config_data, m)
    config_data = _closest_config(config_data, k)
    config_data = _closest_config(config_data, n)

    # slice by moe-intermediate-size if applicable
    if moe_intermediate_size is not None:
        config_data = _closest_config(config_data, moe_intermediate_size)

    assert config_data is not None
    return config_data
|
||||
|
||||
|
||||
@lru_cache
def supports_pdl(device: torch.device | None = None) -> bool:
    """
    Whether Programmatic Dependent Launch can be used on this platform.

    Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py
    """
    # PDL requires a CUDA device of compute capability SM90 or above.
    if not current_platform.is_cuda():
        return False
    return current_platform.has_device_capability(90)
|
||||
6
vllm/lora/ops/xla_ops/__init__.py
Normal file
6
vllm/lora/ops/xla_ops/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
|
||||
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
|
||||
141
vllm/lora/ops/xla_ops/lora_ops.py
Normal file
141
vllm/lora/ops/xla_ops/lora_ops.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_xla.core.xla_builder as xb
|
||||
from torch.library import impl
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
|
||||
|
||||
|
||||
@jax.jit
def bgmv_jax(inputs, loras, idxs):
    """Batched grouped matrix-vector product in JAX.

    For each token t, computes inputs[t] @ loras[idxs[t]].T. The per-token
    LoRA selection is expressed as a one-hot matmul folded into a single
    einsum (X indexes LoRAs, t tokens, d input dim, l output dim).
    """
    selector = jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype)
    return jnp.einsum("td,tX,Xld->tl", inputs, selector, loras)
|
||||
|
||||
|
||||
# Register the schema for the custom "bgmv" op; implementations for the XLA
# and CompositeExplicitAutograd dispatch keys are attached elsewhere in this
# module via @impl.
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
    """XLA implementation of bgmv: forwards to the jitted JAX kernel."""
    # Collapse a (num_loras, 1, l, d) stack to (num_loras, l, d).
    if loras.dim() == 4:
        loras = loras.squeeze(axis=1)

    # Ensure JAX can be imported safely inside the torch_xla context.
    jax_import_guard()
    return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
|
||||
|
||||
|
||||
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
    """Shape-only fallback for non-XLA dispatch: returns an uninitialized
    (num_tokens, out_features) tensor for tracing purposes."""
    num_tokens, _ = inputs.shape
    if loras.dim() == 4:
        loras = loras.squeeze(axis=1)
    _, out_features, _ = loras.shape

    # NOTE(review): dtype is left at the default (float32) rather than
    # matching inputs.dtype — confirm callers don't rely on dtype here.
    return torch.empty((num_tokens, out_features), device=inputs.device)
|
||||
|
||||
|
||||
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    """
    Apply the LoRA "expand" (B-matrix) projection via the custom bgmv op.

    Args:
        inputs (torch.Tensor): [num_tokens, hidden_size] activations.
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): [num_tokens, hidden_size * num_slices]
            destination shape; not mutated — a new tensor is returned.
        lora_indices_tensor (torch.Tensor): [num_tokens] LoRA index per token.
        add_inputs (bool): whether to add `output_tensor` into the result.
    """

    projected = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    # If the op collapsed to a single row while the output expects several,
    # only that first row is meaningful.
    if projected.shape[0] == 1 and output_tensor.shape[0] != 1:
        row_limit = 1
    else:
        row_limit = output_tensor.shape[0]

    # Zero-pad on the right so the column counts line up.
    pad_cols = output_tensor.shape[1] - projected.shape[1]
    if pad_cols > 0:
        projected = F.pad(projected, (0, pad_cols, 0, 0))

    clipped = projected[:row_limit, : output_tensor.shape[1]]
    return output_tensor + clipped if add_inputs else clipped
|
||||
|
||||
|
||||
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    """
    Apply the LoRA "shrink" projection via the custom bgmv op, scaled.

    Args:
        inputs (torch.Tensor): [num_tokens, hidden_size] activations.
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        lora_indices_tensor (torch.Tensor): [num_tokens] LoRA index per token.
        scaling (float, optional): scalar multiplier applied to the result.
    """

    projected = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
    return scaling * projected
|
||||
|
||||
|
||||
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """
    LoRA "expand" into one column slice of a fused output.

    The bgmv result is zero-padded so that it occupies columns
    [slice_offset, slice_offset + slice_size) of a tensor shaped like
    `output_tensor`.

    Args:
        inputs (torch.Tensor): [num_tokens, hidden_size] activations.
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): [num_tokens, hidden_size * num_slices]
            destination shape; not mutated — a new tensor is returned.
        lora_indices_tensor (torch.Tensor): [num_tokens] LoRA index per token.
        slice_offset (int): starting column of this slice.
        slice_size (int): width of this slice in columns.
        add_inputs (bool): whether to add `output_tensor` into the result.
    """
    projected = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    # Pad left up to slice_offset and right up to the full output width.
    right_pad = output_tensor.shape[1] - (slice_offset + slice_size)
    projected = F.pad(projected, (slice_offset, right_pad, 0, 0))

    if add_inputs:
        return output_tensor + projected
    return projected
|
||||
Reference in New Issue
Block a user