Sync from v0.13
This commit is contained in:
295
vllm/lora/ops/triton_ops/utils.py
Normal file
295
vllm/lora/ops/triton_ops/utils.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
|
||||
|
||||
|
||||
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
|
||||
"""
|
||||
`_LORA_A_PTR_DICT` collects the required information during `profile_run`,
|
||||
After this, it remains constant and subsequent usage is through LUT.
|
||||
Refer to:
|
||||
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
|
||||
"""
|
||||
key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
|
||||
|
||||
if values := _LORA_A_PTR_DICT.get(key):
|
||||
return values
|
||||
|
||||
lora_strides_d0 = []
|
||||
lora_strides_d1 = []
|
||||
lora_strides_d2 = []
|
||||
tensor_ptrs = []
|
||||
for lora_a_weight in lora_a_weights:
|
||||
if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_a_weight.size(1) == 1
|
||||
lora_a_weight = lora_a_weight.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank)
|
||||
assert lora_a_weight.is_contiguous()
|
||||
tensor_ptrs.append(lora_a_weight.data_ptr())
|
||||
lora_strides_d0.append(lora_a_weight.stride(0))
|
||||
lora_strides_d1.append(lora_a_weight.stride(1))
|
||||
lora_strides_d2.append(lora_a_weight.stride(2))
|
||||
if len(lora_a_weights) > 1:
|
||||
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
|
||||
else:
|
||||
lora_ptr_tensor = lora_a_weights[0]
|
||||
|
||||
if (
|
||||
len(set(lora_strides_d0)) > 1
|
||||
or len(set(lora_strides_d1)) > 1
|
||||
or len(set(lora_strides_d2)) > 1
|
||||
):
|
||||
raise ValueError("All LoRA weights must have the same stride.")
|
||||
|
||||
_LORA_A_PTR_DICT[key] = (
|
||||
lora_ptr_tensor,
|
||||
lora_strides_d0[0],
|
||||
lora_strides_d1[0],
|
||||
lora_strides_d2[0],
|
||||
)
|
||||
return _LORA_A_PTR_DICT.get(key)
|
||||
|
||||
|
||||
def _get_lora_b_ptr(
|
||||
lora_weights: list[torch.Tensor], offset_start: int, device: torch.device
|
||||
):
|
||||
"""
|
||||
`_LORA_B_PTR_DICT` collects the required information during `profile_run`,
|
||||
After this, it remains constant and subsequent usage is through LUT.
|
||||
Refer to:
|
||||
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
|
||||
|
||||
"""
|
||||
|
||||
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
|
||||
if values := _LORA_B_PTR_DICT.get(key):
|
||||
return values
|
||||
slice_offset_lst = []
|
||||
tensor_ptrs = []
|
||||
lora_strides_d0 = []
|
||||
lora_strides_d1 = []
|
||||
lora_strides_d2 = []
|
||||
hidden_sizes = []
|
||||
slice_offset = offset_start
|
||||
for lora_b_weight in lora_weights:
|
||||
if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weight.size(1) == 1
|
||||
lora_b_weight = lora_b_weight.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank)
|
||||
assert lora_b_weight.is_contiguous()
|
||||
tensor_ptrs.append(lora_b_weight.data_ptr())
|
||||
lora_strides_d0.append(lora_b_weight.stride(0))
|
||||
lora_strides_d1.append(lora_b_weight.stride(1))
|
||||
lora_strides_d2.append(lora_b_weight.stride(2))
|
||||
slice_offset_lst.append(slice_offset)
|
||||
slice_offset += lora_b_weight.size(1)
|
||||
hidden_sizes.append(lora_b_weight.size(1))
|
||||
|
||||
if len(lora_weights) > 1:
|
||||
# note these are device tensors
|
||||
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
|
||||
slice_start_tensor = torch.tensor(
|
||||
slice_offset_lst, device=device, dtype=torch.uint64
|
||||
)
|
||||
else:
|
||||
slice_start_tensor = slice_offset_lst[0]
|
||||
lora_ptr_tensor = lora_b_weight[0]
|
||||
|
||||
# If each lora has the same stride, there's no need to use a
|
||||
# tensor for storage.
|
||||
if (
|
||||
len(set(lora_strides_d0)) == 1
|
||||
and len(set(lora_strides_d1)) == 1
|
||||
and len(set(lora_strides_d2)) == 1
|
||||
) and len(set(hidden_sizes)) == 1:
|
||||
lora_strides_d0_tensor = lora_strides_d0[0]
|
||||
lora_strides_d1_tensor = lora_strides_d1[0]
|
||||
lora_strides_d2_tensor = lora_strides_d2[0]
|
||||
hidden_sizes_tensor = hidden_sizes[0]
|
||||
same_stride = True
|
||||
|
||||
else:
|
||||
lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
|
||||
lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
|
||||
lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
|
||||
hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
|
||||
same_stride = False
|
||||
# MAX_N is the maximum hidden size among all the lora_b weights
|
||||
MAX_N = max(hidden_sizes)
|
||||
_LORA_B_PTR_DICT[key] = (
|
||||
slice_start_tensor,
|
||||
lora_ptr_tensor,
|
||||
lora_strides_d0_tensor,
|
||||
lora_strides_d1_tensor,
|
||||
lora_strides_d2_tensor,
|
||||
hidden_sizes_tensor,
|
||||
same_stride,
|
||||
MAX_N,
|
||||
)
|
||||
return _LORA_B_PTR_DICT.get(key)
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
|
||||
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
|
||||
if user_defined_config_folder is not None:
|
||||
gpu_name = torch.cuda.get_device_name()
|
||||
gpu_name = gpu_name.replace(" ", "_")
|
||||
gpu_name = gpu_name.replace("-", "_")
|
||||
|
||||
config_fname = None
|
||||
# only expand op needs to consider add_inputs
|
||||
if op_type == "expand":
|
||||
config_fname = (
|
||||
f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
|
||||
)
|
||||
else:
|
||||
config_fname = f"{gpu_name}_{op_type.upper()}.json"
|
||||
|
||||
config_path = Path(f"{user_defined_config_folder}/{config_fname}")
|
||||
if not config_path.exists():
|
||||
logger.warning_once(f"No LoRA kernel configs founded in {config_path}")
|
||||
return None
|
||||
|
||||
# Load json
|
||||
logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
|
||||
with open(str(config_path)) as f:
|
||||
config_data = json.load(f)
|
||||
else:
|
||||
config_data = None
|
||||
|
||||
return config_data
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_lora_op_configs(
|
||||
op_type: str,
|
||||
max_loras: int,
|
||||
batch: int,
|
||||
hidden_size: int,
|
||||
rank: int,
|
||||
num_slices: int,
|
||||
add_inputs: bool | None = None,
|
||||
moe_intermediate_size: int | None = None,
|
||||
) -> dict[str, int | None]:
|
||||
# Add support for fused_moe_lora ops
|
||||
assert op_type in [
|
||||
"shrink",
|
||||
"expand",
|
||||
"fused_moe_lora_w13_shrink",
|
||||
"fused_moe_lora_w13_expand",
|
||||
"fused_moe_lora_w2_shrink",
|
||||
"fused_moe_lora_w2_expand",
|
||||
]
|
||||
|
||||
# default config
|
||||
default = {}
|
||||
if op_type == "shrink":
|
||||
default = {
|
||||
"block_m": 32,
|
||||
"block_n": 16,
|
||||
"block_k": 256 if batch < 128 else 32,
|
||||
"split_k": 64 if batch < 128 else 8,
|
||||
"num_warps": 4,
|
||||
"num_ctas": 1,
|
||||
"group_size_m": 8,
|
||||
"num_stages": 2,
|
||||
"max_nreg": None,
|
||||
}
|
||||
# The default config for fused_moe_lora ops
|
||||
elif op_type in [
|
||||
"fused_moe_lora_w13_shrink",
|
||||
"fused_moe_lora_w13_expand",
|
||||
"fused_moe_lora_w2_shrink",
|
||||
"fused_moe_lora_w2_expand",
|
||||
]:
|
||||
default = {
|
||||
"block_m": 64,
|
||||
"block_n": 64,
|
||||
"block_k": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
"group_size_m": 8,
|
||||
"split_k": 1,
|
||||
}
|
||||
else:
|
||||
default = {
|
||||
"block_m": 64,
|
||||
"block_n": 128,
|
||||
"block_k": 16,
|
||||
"num_warps": 4,
|
||||
"num_ctas": 1,
|
||||
"num_stages": 2,
|
||||
"max_nreg": None,
|
||||
}
|
||||
m = batch
|
||||
|
||||
k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
|
||||
|
||||
config_data: Any
|
||||
config_data = load_lora_op_config(op_type, add_inputs)
|
||||
if not config_data:
|
||||
logger.warning_once("Using default LoRA kernel configs")
|
||||
return default
|
||||
|
||||
# config is structured as config_data[max_loras][num_slices][m][k][n] = {}
|
||||
# slice by max_loras
|
||||
config_data = (
|
||||
config_data.get(str(max_loras))
|
||||
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
|
||||
)
|
||||
# slice by num_slices
|
||||
config_data = config_data[str(num_slices)]
|
||||
# slice by m
|
||||
config_data = (
|
||||
config_data.get(str(m))
|
||||
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
|
||||
)
|
||||
# slice by k
|
||||
config_data = (
|
||||
config_data.get(str(k))
|
||||
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
|
||||
)
|
||||
# slice by n
|
||||
config_data = (
|
||||
config_data.get(str(n))
|
||||
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
|
||||
)
|
||||
|
||||
# slice by moe-intermediate-size if applicable
|
||||
if moe_intermediate_size is not None:
|
||||
i = moe_intermediate_size
|
||||
config_data = (
|
||||
config_data.get(str(i))
|
||||
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))]
|
||||
)
|
||||
|
||||
assert config_data is not None
|
||||
return config_data
|
||||
|
||||
|
||||
@lru_cache
|
||||
def supports_pdl(device: torch.device | None = None) -> bool:
|
||||
"""
|
||||
Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py
|
||||
"""
|
||||
# PDL requires compute capability SM90 or above
|
||||
return current_platform.is_cuda() and current_platform.has_device_capability(90)
|
||||
Reference in New Issue
Block a user