Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     UnfusedOAITritonExperts,
 )
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEModularKernel,
+    FusedMoEKernel,
 )
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
+    MoEPrepareAndFinalizeNoDPEPModular,
 )

 from .utils import _get_lora_device, try_get_optimal_moe_lora_config
@@ -83,7 +83,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     ):
         if envs.VLLM_TUNED_CONFIG_FOLDER:
             hidden_size = layer.hidden_size
-            intermediate_size = layer.intermediate_size_per_partition
+            intermediate_size = (
+                self.w2_lora_a_stacked[0].shape[-1]
+                if op_prefix == "w2"
+                else self.w13_lora_b_stacked[0].shape[-2]
+            )
             shrink_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_shrink",
                 max_loras=num_loras,
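The hunk above swaps the tuned-config `intermediate_size` source from the base layer to the stacked LoRA weights, since the LoRA GEMMs are sized by the adapter tensors. A minimal shape sketch of why each branch recovers the intermediate dimension; the stacking layout below is assumed for illustration, not taken from the diff:

```python
import torch

max_loras, num_experts, rank = 4, 8, 16
intermediate = 2816

# Assumed layouts: w2 maps intermediate -> hidden, so its LoRA-A reads the
# intermediate dim last; w13's LoRA-B writes the intermediate dim second-to-last.
w2_lora_a_stacked = (torch.empty(max_loras, num_experts, rank, intermediate),)
w13_lora_b_stacked = (torch.empty(max_loras, num_experts, intermediate, rank),)

assert w2_lora_a_stacked[0].shape[-1] == intermediate   # op_prefix == "w2"
assert w13_lora_b_stacked[0].shape[-2] == intermediate  # w13 branch
```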
@@ -132,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):

         if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
             # Use the existing modular kernel from the quant method
-            m_fused_moe_fn = self.base_layer.quant_method.moe_mk
+            m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
             # Don't let the kernel own shared experts so the runner can
             # overlap them with routed experts via a separate CUDA stream.
             m_fused_moe_fn.shared_experts = None
@@ -140,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             # Create a new modular kernel via select_gemm_impl.
             # Don't pass shared_experts to the kernel so the runner can
             # overlap them with routed experts via a separate CUDA stream.
-            prepare_finalize = MoEPrepareAndFinalizeNoEP()
-            m_fused_moe_fn = FusedMoEModularKernel(
+            prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
+            m_fused_moe_fn = FusedMoEKernel(
                 prepare_finalize,
                 self.base_layer.quant_method.select_gemm_impl(
                     prepare_finalize, self.base_layer
@@ -150,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):

         if quant_config.use_mxfp4_w4a16:
             assert isinstance(
-                m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
+                m_fused_moe_fn.impl.fused_experts,
+                (MarlinExperts, UnfusedOAITritonExperts),
             )
         else:
-            assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts)
+            assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)

         def fwd_decorator(layer, func):
             def wrapper(*args, **kwargs):
@@ -333,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):

             return wrapper

-        fused_experts = m_fused_moe_fn.fused_experts
+        fused_experts = m_fused_moe_fn.impl.fused_experts

-        m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
+        m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
         fused_experts.activation = act_decorator(
             self.base_layer, fused_experts.activation
         )
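The rename from `forward` to `apply` tracks the new `FusedMoEKernel` entry point; the decorator pattern itself is unchanged. A self-contained sketch of that rebinding trick (the `Kernel` class is hypothetical, standing in for the modular kernel):

```python
def fwd_decorator(layer, func):
    # Wrap a bound method, injecting per-layer setup before delegating.
    def wrapper(*args, **kwargs):
        _ = layer  # LoRA-specific preparation would use `layer` here
        return func(*args, **kwargs)
    return wrapper

class Kernel:
    def apply(self, x):
        return x + 1

kernel, base_layer = Kernel(), object()
kernel.apply = fwd_decorator(base_layer, kernel.apply)  # instance attr shadows the method
assert kernel.apply(1) == 2
```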
@@ -88,10 +88,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         model_config: PretrainedConfig | None = None,
     ) -> None:
         # TODO: Verify if this condition can be further relaxed
-        if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048:
-            raise ValueError(
-                "When using LoRA, vocab size must be > 32000 and <= 258048"
-            )
+        if self.base_layer.vocab_size > 258048:
+            raise ValueError("When using LoRA, vocab size must be <= 258048")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -8,9 +8,10 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.triton_utils import tl, triton
+from vllm.triton_utils.allocation import set_triton_allocator
 from vllm.utils.torch_utils import direct_register_custom_op

-from .utils import supports_pdl
+from .utils import supports_pdl, supports_tma


 @triton.jit
@@ -70,6 +71,37 @@ def _get_token_offs(
     )


+@triton.jit
+def _get_c_ptrs(
+    cur_c_ptr,
+    lora_id,
+    pid_m,
+    offs,
+    offs_token,
+    offs_cn,
+    stride_cm,
+    stride_cn,
+    EM: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    sort_c: tl.constexpr,
+):
+    # When sort_c is true, store the output in c_ptr using the token order defined
+    # in sorted_token_ids_ptr; otherwise, use the original token order from the prompt.
+    if sort_c:
+        offs_token_id = pid_m * BLOCK_SIZE_M + offs
+        c_ptrs = (
+            cur_c_ptr
+            + lora_id * EM * stride_cm
+            + stride_cm * offs_token_id[:, None]
+            + stride_cn * offs_cn[None, :]
+        )
+    else:
+        c_ptrs = (
+            cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+        )
+    return c_ptrs
+
+
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}


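The new `_get_c_ptrs` helper only decides where a block's rows land in C: with `sort_c` set, rows follow their position inside `sorted_token_ids` within a per-LoRA slab of `EM` rows; otherwise they scatter back to the original token ids. A host-side sketch of the same index arithmetic, with made-up sizes:

```python
import torch

EM, BLOCK_SIZE_M, N = 16, 4, 8
lora_id, pid_m = 1, 2
offs = torch.arange(BLOCK_SIZE_M)
offs_token = torch.tensor([7, 3, 11, 5])  # illustrative sorted token ids
c = torch.zeros(2 * EM, N)                # two per-LoRA slabs of EM rows

rows_sorted = lora_id * EM + pid_m * BLOCK_SIZE_M + offs  # sort_c=True (shrink)
rows_scattered = offs_token                               # sort_c=False (expand)

c[rows_sorted] = 1.0     # contiguous, TMA-friendly store
c[rows_scattered] = 2.0  # store in original token order
```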
@@ -95,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):


 def _adjust_kernel_inputs(
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     sorted_token_ids: torch.Tensor | None,
     expert_ids: torch.Tensor,
 ):
@@ -109,7 +141,7 @@ def _adjust_kernel_inputs(
     else:
         stride_tl = sorted_token_ids.stride(0)
         stride_el = expert_ids.stride(0)
-    grid_lora_dim = num_active_loras
+    grid_lora_dim = num_active_loras.item()
     return grid_lora_dim, stride_tl, stride_el


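`num_active_loras.item()` looks like a sync point, but because the tensor lives on the CPU it is a plain host-memory read; no device synchronization is involved. Minimal illustration:

```python
import torch

num_active_loras = torch.tensor([3], dtype=torch.int32)  # CPU tensor, as in the diff
grid_lora_dim = num_active_loras.item()  # host read only, no GPU sync
assert grid_lora_dim == 3
```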
@@ -125,7 +157,9 @@ def _adjust_kernel_inputs(
 )
 def _fused_moe_lora_kernel(
     a_ptr,
+    a_desc,
     b_ptr,
+    b_desc,
     c_ptr,
     topk_weights_ptr,
     sorted_token_ids_ptr,
@@ -177,6 +211,18 @@ def _fused_moe_lora_kernel(
     USE_GDC: tl.constexpr,
     launch_pdl: tl.constexpr,
     IS_PRIMARY: tl.constexpr,
+    USE_TMA: tl.constexpr,
+    # sort_c determines whether tokens are stored in C in the order determined
+    # by sorted_token_ids, to enable later TMA loads from this tensor.
+    #
+    # When USE_TMA is enabled, the parameter combinations are:
+    #  a_desc | b_desc | sort_c | Use Case
+    # --------|--------|--------|-----------------------------
+    #  yes    | yes    | False  | expand kernel (num_slices=1)
+    #  no     | yes    | True   | shrink kernel (num_slices=1)
+    #  yes    | no     | False  | expand kernel (num_slices>1)
+    #  no     | no     | True   | shrink kernel (num_slices>1)
+    sort_c: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
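The table above fully determines each launch's descriptor/flag combination. A hedged host-side restatement of that selection logic (names are illustrative, and this ignores the `sorted_token_ids is None` fallback that the real shrink launch also checks):

```python
def pick_tma_config(is_shrink: bool, num_slices: int, use_tma: bool) -> dict:
    """Mirror the a_desc/b_desc/sort_c table from the kernel docstring."""
    if not use_tma:
        return {"a_desc": None, "b_desc": None, "sort_c": False}
    return {
        "a_desc": None if is_shrink else "desc(A)",        # expand loads A via TMA
        "b_desc": "desc(B)" if num_slices == 1 else None,  # single-slice B via TMA
        "sort_c": is_shrink,                               # shrink stores C sorted
    }

assert pick_tma_config(True, 1, True) == {"a_desc": None, "b_desc": "desc(B)", "sort_c": True}
assert pick_tma_config(False, 2, True) == {"a_desc": "desc(A)", "b_desc": None, "sort_c": False}
```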
@@ -250,58 +296,90 @@ def _fused_moe_lora_kernel(
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

-    # remove modulo wrap-around
-    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
     token_mask = offs_token < num_valid_tokens

-    # get a_ptrs,b_ptrs
-    a_ptrs = cur_a_ptr + (
-        offs_token[:, None] // token_mapping_factor * stride_am
-        + offs_k[None, :] * stride_ak
-    )
+    if USE_TMA and a_desc is not None:
+        # Expand path - with TMA enabled, load from A using the TMA descriptor
+        offs_am = (
+            slice_id * max_loras * EM
+            + lora_id * EM
+            + pid_m * BLOCK_SIZE_M // token_mapping_factor
+        )
+        offs_ak = pid_sk * BLOCK_SIZE_K
+    else:
+        # Shrink path - load hidden states in the order defined by
+        # 'sorted_token_ids_ptr', then store them in c_ptr in this same sorted order
+        tl.static_assert(a_desc is None, "a_desc must be None")
+        a_ptrs = cur_a_ptr + (
+            offs_token[:, None] // token_mapping_factor * stride_am
+            + offs_k[None, :] * stride_ak
+        )

-    b_ptrs = (
-        cur_b_ptr
-        + lora_id * stride_bl
-        + expert_id * stride_be
-        + offs_k[:, None] * stride_bk
-        + offs_bn[None, :] * stride_bn
-    )
+    if USE_TMA:
+        offs_bn = pid_n * BLOCK_SIZE_N
+        offs_bk = pid_sk * BLOCK_SIZE_K
+        if b_desc is None:
+            # Note(@gnovack) - Allocation of TMA descriptors on-device
+            # can cause conflicts when running in parallel via PDL
+            if USE_GDC and not IS_PRIMARY:
+                tl.extra.cuda.gdc_wait()
+
+            b_desc = tl.make_tensor_descriptor(
+                cur_b_ptr,
+                shape=[max_loras, num_experts, N, K],
+                strides=[stride_bl, stride_be, stride_bn, stride_bk],
+                block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+            )
+    else:
+        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
+        b_ptrs = (
+            cur_b_ptr
+            + lora_id * stride_bl
+            + expert_id * stride_be
+            + offs_k[:, None] * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )

     if USE_GDC and IS_PRIMARY:
         # GDC launch_dependents hints the runtime to launch dependent kernels.
         tl.extra.cuda.gdc_launch_dependents()

     # accumulator
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

+    if USE_GDC and not IS_PRIMARY:
+        # GDC wait waits for ALL programs in the prior kernel to complete
+        # before continuing.
+        tl.extra.cuda.gdc_wait()
+
     for k in range(0, grid_k):
-        k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
+        cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K)
+        k_remaining = K - cur_k_offset
         # pre-fetch lora weight
-        # add (offs_bn < N) mask; optional .ca for B
-        b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
-        if USE_B_L2_CACHE:
-            b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
+        if b_desc is not None:
+            b = (
+                b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset])
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+                .T
+            )
         else:
-            b = tl.load(b_ptrs, mask=b_mask, other=0.0)
+            # add (offs_bn < N) mask; optional .ca for B
+            b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
+            if USE_B_L2_CACHE:
+                b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
+            else:
+                b = tl.load(b_ptrs, mask=b_mask, other=0.0)
+            b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk

-        if USE_GDC and not IS_PRIMARY:
-            tl.extra.cuda.gdc_wait()
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
-            other=0.0,
-        )
+        if a_desc is not None:
+            a = a_desc.load([offs_am, offs_ak + cur_k_offset])
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
+                other=0.0,
+            )
+            a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak

         accumulator += tl.dot(a, b)
-        # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk

     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
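Whichever load path a program takes (TMA descriptor or masked `tl.load`), the tile math is the same blocked accumulation. A plain-PyTorch oracle for one (lora, expert) tile, useful when sanity-checking the kernel; shapes are made up:

```python
import torch

BLOCK_M, BLOCK_N, BLOCK_K, K = 4, 8, 16, 64
a = torch.randn(BLOCK_M, K)
b = torch.randn(K, BLOCK_N)

acc = torch.zeros(BLOCK_M, BLOCK_N)
for k0 in range(0, K, BLOCK_K):  # the kernel's grid_k loop with SPLIT_K=1
    acc += a[:, k0:k0 + BLOCK_K] @ b[k0:k0 + BLOCK_K, :]

torch.testing.assert_close(acc, a @ b)
```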
@@ -309,7 +387,19 @@ def _fused_moe_lora_kernel(
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_ptrs = _get_c_ptrs(
+        cur_c_ptr,
+        lora_id,
+        pid_m,
+        offs,
+        offs_token,
+        offs_cn,
+        stride_cm,
+        stride_cn,
+        EM,
+        BLOCK_SIZE_M,
+        sort_c,
+    )
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)

     if SPLIT_K == 1:
@@ -354,9 +444,10 @@ def _fused_moe_lora_shrink(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     use_gdc: bool = False,
+    use_tma: bool = False,
 ) -> None:
     w1_lora_a_stacked = lora_a_stacked[0]
     shrink_config = {
@@ -369,6 +460,7 @@ def _fused_moe_lora_shrink(
         "SPLIT_K": split_k,
         "USE_GDC": use_gdc,
         "launch_pdl": use_gdc,  # triton kernel metadata
+        "USE_TMA": use_tma,
     }

     b_ptr = _get_ptr(lora_a_stacked, device)
@@ -383,9 +475,20 @@ def _fused_moe_lora_shrink(
         len(lora_a_stacked),
         grid_lora_dim,
     )

+    a_desc = None
+    b_desc = None
+    if use_tma and num_slices == 1:
+        b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+            lora_a_stacked[0],
+            [1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]],
+        )
+
     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
+        a_desc,
         b_ptr,
+        b_desc,
         a_intermediate_cache1,
         topk_weights,
         sorted_token_ids,
@@ -407,8 +510,8 @@ def _fused_moe_lora_shrink(
         w1_lora_a_stacked.stride(1),
         w1_lora_a_stacked.stride(3),
         w1_lora_a_stacked.stride(2),
-        a_intermediate_cache1.stride(2),
-        a_intermediate_cache1.stride(3),
+        a_intermediate_cache1.stride(-2),
+        a_intermediate_cache1.stride(-1),
         stride_tl,
         stride_el,
         slice_a_size=qcurr_hidden_states.numel(),
@@ -419,7 +522,8 @@ def _fused_moe_lora_shrink(
         naive_block_assignment=sorted_token_ids is None,
         MUL_ROUTED_WEIGHT=False,
         ADD_INPUTS=False,
-        USE_B_L2_CACHE=True,  # new
+        USE_B_L2_CACHE=True,
+        sort_c=use_tma and sorted_token_ids is not None,
         IS_PRIMARY=True,
         **shrink_config,
     )
@@ -458,10 +562,11 @@ def _fused_moe_lora_expand(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     offset: int = 0,
     use_gdc: bool = False,
+    use_tma: bool = False,
 ) -> None:
     b_ptr = _get_ptr(lora_b_stacked, device)
     K = max_lora_rank
@@ -470,7 +575,7 @@ def _fused_moe_lora_expand(
     w1_lora_b_stacked = lora_b_stacked[0]

     a_intermediate_cache1 = a_intermediate_cache1.view(
-        -1, a_intermediate_cache1.shape[3]
+        -1, a_intermediate_cache1.shape[-1]
     )

     expand_config = {
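Switching `shape[3]` to `shape[-1]` keeps the flatten correct for both cache layouts introduced by this commit; the last dimension (`max_lora_rank`) is the invariant. Quick check:

```python
import torch

cache = torch.zeros(2, 8, 4, 16)        # (num_slices, M, top_k_num, max_lora_rank)
flat = cache.view(-1, cache.shape[-1])  # rows of rank-sized vectors
assert flat.shape == (2 * 8 * 4, 16)
```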
@@ -483,6 +588,7 @@ def _fused_moe_lora_expand(
         "SPLIT_K": 1,  # Set split_k = 1 for expand calls
         "USE_GDC": use_gdc,
         "launch_pdl": use_gdc,  # triton kernel metadata
+        "USE_TMA": use_tma,
     }

     grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
@@ -498,10 +604,27 @@ def _fused_moe_lora_expand(
     # Fast path: directly accumulate into the corresponding slice interval of output.
     out_view = output[:, :, offset : offset + num_slices * N]
     slice_c_size = N * out_view.stride(2)
+    a_desc = None
+    b_desc = None
+    if use_tma:
+        if sorted_token_ids is not None:
+            a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+                a_intermediate_cache1,
+                [expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]],
+            )
+        if num_slices == 1:
+            b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+                lora_b_stacked[0],
+                [1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]],
+            )
+        else:
+            b_desc = None
+
     _fused_moe_lora_kernel[grid](
         a_intermediate_cache1,
+        a_desc,
         b_ptr,
+        b_desc,
         out_view,
         topk_weights,
         sorted_token_ids,
@@ -535,7 +658,8 @@ def _fused_moe_lora_expand(
         naive_block_assignment=sorted_token_ids is None,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         ADD_INPUTS=True,
-        USE_B_L2_CACHE=True,  # new
+        USE_B_L2_CACHE=True,
+        sort_c=False,
         IS_PRIMARY=False,
         **expand_config,
     )
@@ -559,7 +683,7 @@ def _fused_moe_lora(
     max_lora_rank: int,
     top_k_num: int,
     lora_ids: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     adapter_enabled: torch.Tensor,
     shrink_block_size_m: int,
     shrink_block_size_n: int,
@@ -616,8 +740,34 @@ def _fused_moe_lora(
         else num_tokens * shrink_block_size_m
     )

+    # TMA is not currently compatible with fully_sharded due to the non-determinism
+    # of token id sorting across ranks.
+    use_tma = supports_tma(device) and not fully_sharded
+
+    intermediate_cache_shape = (
+        num_slices,
+        M,
+        top_k_num,
+        max_lora_rank,
+    )
+    if use_tma:
+        if num_slices > 1:
+            # if num_slices > 1, we construct TMA descriptors for LoRA
+            # weights within the kernel, which requires us to first set an allocator
+            set_triton_allocator(device)
+
+        # When storing intermediate data in sorted order for TMA, we
+        # need an extra 'num_active_loras' dim in the cache to avoid conflicts
+        if sorted_token_ids is not None:
+            intermediate_cache_shape = (
+                num_slices,
+                sorted_token_ids.shape[0],
+                EM,
+                max_lora_rank,
+            )
+
     a_intermediate_cache1 = torch.zeros(
-        (num_slices, M, top_k_num, max_lora_rank),
+        intermediate_cache_shape,
         dtype=output.dtype,
         device=device,
     )
@@ -654,6 +804,7 @@ def _fused_moe_lora(
         num_active_loras,
         mul_routed_weight,
         use_gdc=use_gdc,
+        use_tma=use_tma,
     )

     if fully_sharded:
@@ -703,6 +854,7 @@ def _fused_moe_lora(
         mul_routed_weight,
         offset,
         use_gdc=use_gdc,
+        use_tma=use_tma,
     )


@@ -719,7 +871,7 @@ def _fused_moe_lora_fake(
     max_lora_rank: int,
     top_k_num: int,
     lora_ids: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     adapter_enabled: torch.Tensor,
     shrink_block_size_m: int,
     shrink_block_size_n: int,
@@ -769,9 +921,10 @@ def _fused_moe_lora_shrink_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     use_gdc: bool = False,
+    use_tma: bool = False,
 ) -> None:
     return

@@ -805,10 +958,11 @@ def _fused_moe_lora_expand_fake(
     num_warps: int,
     num_stages: int,
     split_k: int,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     mul_routed_weight: bool = False,
     offset: int = 0,
     use_gdc: bool = False,
+    use_tma: bool = False,
 ) -> None:
     return


@@ -138,7 +138,7 @@ def _lora_expand(
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
     no_lora_flag_cpu: torch.Tensor,  # shape [1]
-    num_active_loras: int,  # number of active LoRAs (unused here, for API compat)
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:
@@ -235,7 +235,7 @@ def _lora_expand(
     grid = (
         triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
         NUM_SLICES,
-        num_active_loras,
+        num_active_loras.item(),
     )
     # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
     # making PDL invalid and affecting the kernel performance.
@@ -289,7 +289,7 @@ def _lora_expand_fake(
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
     no_lora_flag_cpu: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     offset_start: int = 0,
     add_inputs: bool = False,
 ) -> None:

@@ -29,9 +29,16 @@ class LoRAKernelMeta:
     # to early exit from inside the lora_expand / lora_shrink torch operation.
     no_lora_flag_cpu: torch.Tensor

-    # Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
-    # Stored as a Python int to avoid GPU->CPU sync during forward pass
-    num_active_loras: int = 0
+    # Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
+    # Stored as a CPU tensor (not a Python int) so that torch.compile treats
+    # it as a dynamic value rather than baking it as a constant at trace time.
+    # This follows the same pattern as no_lora_flag_cpu above.
+    num_active_loras_cpu: torch.Tensor
+
+    # Default num_active_loras value (max_loras + 1) as a CPU tensor,
+    # used when specialize_active_lora is False to avoid allocating a
+    # new tensor on every meta_args() call.
+    default_num_active_loras_cpu: torch.Tensor

     # Captured LoRA counts for cudagraph specialization (sorted list).
     # When specialize_active_lora is enabled, num_active_loras is rounded up
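The int-to-CPU-tensor switch is the standard trick for keeping a scalar dynamic under torch.compile: a Python int read during tracing is typically specialized on its value (one graph or guard per value), while a tensor participates in the trace as data. A minimal sketch of the difference, assuming a recent PyTorch:

```python
import torch

@torch.compile
def scale(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    # `n` is traced as data; one compiled graph serves every value.
    return x * n

x = torch.ones(4)
for value in (1, 2, 3):
    scale(x, torch.tensor([value]))  # no per-value recompilation

# Passing `value` as a plain int instead would bake it into the graph,
# recompiling (or re-guarding) for each distinct count.
```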
@@ -73,6 +80,11 @@ class LoRAKernelMeta:

         no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")

+        num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
+        default_num_active_loras_cpu = torch.tensor(
+            [max_loras + 1], dtype=torch.int32, device="cpu"
+        )
+
         return LoRAKernelMeta(
             token_lora_mapping=token_lora_mapping,
             token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
@@ -80,6 +92,8 @@ class LoRAKernelMeta:
             num_tokens_per_lora=num_tokens_per_lora,
             lora_token_start_loc=lora_token_start_loc,
             no_lora_flag_cpu=no_lora_flag_cpu,
+            num_active_loras_cpu=num_active_loras_cpu,
+            default_num_active_loras_cpu=default_num_active_loras_cpu,
             captured_lora_counts=sorted(captured_lora_counts)
             if captured_lora_counts
             else [],
@@ -90,8 +104,7 @@ class LoRAKernelMeta:
         self.num_tokens_per_lora.fill_(0)
         self.lora_token_start_loc.fill_(0)
         self.no_lora_flag_cpu.fill_(False)
-        self.num_active_loras = 0
-        self.captured_lora_counts = []
+        self.num_active_loras_cpu.fill_(0)

     def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
         """
@@ -137,14 +150,16 @@ class LoRAKernelMeta:
             num_tokens_per_lora, non_blocking=True
         )

-        self.num_active_loras = lora_ids.size(0)
+        num_active_loras = lora_ids.size(0)

         # Round up num_active_loras to match cudagraph capture keys.
         # This ensures the kernel grid dimension matches the captured graph.
-        if self.captured_lora_counts and self.num_active_loras > 0:
-            idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
+        if self.captured_lora_counts and num_active_loras > 0:
+            idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
             if idx < len(self.captured_lora_counts):
-                self.num_active_loras = self.captured_lora_counts[idx]
+                num_active_loras = self.captured_lora_counts[idx]
+
+        self.num_active_loras_cpu[0] = num_active_loras

         # lora_token_start_loc
         lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
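Rounding the live LoRA count up to the nearest captured cudagraph bucket is a direct `bisect_left` application; counts above the largest bucket fall through unchanged. Standalone restatement:

```python
import bisect

captured_lora_counts = [1, 2, 4, 8]  # sorted cudagraph capture keys

def round_up_to_captured(n: int) -> int:
    idx = bisect.bisect_left(captured_lora_counts, n)
    return captured_lora_counts[idx] if idx < len(captured_lora_counts) else n

assert round_up_to_captured(3) == 4  # replayed on the 4-LoRA graph
assert round_up_to_captured(8) == 8  # exact bucket match
```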
@@ -163,7 +178,7 @@ class LoRAKernelMeta:
         torch.Tensor,
         torch.Tensor,
         torch.Tensor,
-        int,
+        torch.Tensor,
     ]:
         """
         This function returns the kernel metadata required for the current
@@ -175,7 +190,10 @@ class LoRAKernelMeta:
             token_nums (int): Number of input tokens in the current forward
                 pass of the kernel.
         """
-        max_loras = self.active_lora_ids.size(0) - 1
+        if specialize_active_lora:
+            num_active_loras = self.num_active_loras_cpu
+        else:
+            num_active_loras = self.default_num_active_loras_cpu
         return (
             self.token_lora_mapping[:token_nums],
             self.token_indices_sorted_by_lora_ids[:token_nums],
@@ -183,5 +201,5 @@ class LoRAKernelMeta:
             self.lora_token_start_loc,
             self.active_lora_ids,
             self.no_lora_flag_cpu,
-            self.num_active_loras if specialize_active_lora else max_loras + 1,
+            num_active_loras,
         )

@@ -134,7 +134,7 @@ def _lora_shrink(
     lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
     lora_ids: torch.Tensor,  # shape [max-loras + 1]
     no_lora_flag_cpu: torch.Tensor,  # shape [1]
-    num_active_loras: int,  # number of active LoRAs (unused here, for API compat)
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     scaling: float,
 ) -> None:
     """
@@ -157,6 +157,9 @@ def _lora_shrink(
         lora_ids (torch.Tensor): LoRA ids to process.
         no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
             if there are any requests that require LoRA.
+        num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
+            number of active LoRAs. Stored as a tensor (not an int) so
+            torch.compile treats it as dynamic rather than a constant.
         scaling (float): Scaling factor.
     """

@@ -215,7 +218,7 @@ def _lora_shrink(
     grid = (
         SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
         NUM_SLICES,
-        num_active_loras,
+        num_active_loras.item(),
     )
     # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
     # making PDL invalid and affecting the kernel performance.
@@ -267,7 +270,7 @@ def _lora_shrink_fake(
     lora_token_start_loc: torch.Tensor,
     lora_ids: torch.Tensor,
     no_lora_flag_cpu: torch.Tensor,
-    num_active_loras: int,
+    num_active_loras: torch.Tensor,  # CPU tensor [1], number of active LoRAs
     scaling: float,
 ) -> None:
     return

@@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool:
         and current_platform.has_device_capability(90)
         and not envs.VLLM_LORA_DISABLE_PDL
     )
+
+
+@lru_cache
+def supports_tma(device: torch.device | None = None) -> bool:
+    # TMA requires compute capability SM90 or above
+    return current_platform.is_cuda() and current_platform.has_device_capability(90)
@@ -233,6 +233,17 @@ class PunicaWrapperGPU(PunicaWrapperBase):

         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)

+        import vllm.envs as env
+        if env.VLLM_USE_LORA_FUSION:
+            import ixformer.inference.functions as ops
+
+            num_token, m = x.size(0), x.size(-1)
+            k, n = lora_b_stacked[0].size(-1), y.size(-1)
+            if len(lora_a_stacked) == 1 and ops.lora_gemv_optim_condition(num_token, m, k, n):
+                ops.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked,
+                                    lora_bias_stacked=None, scale=1.0, output_slices=(1,))
+                return
+
         assert buffer is None, (
             "To minimize overhead, the buffer should be created by "
             ".add_lora_linear() instead of being passed in."
@@ -351,6 +362,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
         if pad_sorted_ids:
             max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
+        if topk_ids.numel() < num_experts:
+            max_num_tokens_padded = topk_ids.numel() * block_size
         sorted_ids = torch.empty(
             (max_loras * max_num_tokens_padded,),
             dtype=torch.int32,
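The new clamp matters when only a handful of tokens are routed: each token can pad out at most one block per expert it hits, so the allocation is bounded by `numel * block_size`. Worked numbers (ignoring the `pad_sorted_ids` round-up):

```python
def padded_len(num_topk_ids: int, num_experts: int, block_size: int) -> int:
    max_padded = num_topk_ids + num_experts * (block_size - 1)
    if num_topk_ids < num_experts:
        max_padded = num_topk_ids * block_size  # cap for tiny batches
    return max_padded

assert padded_len(8, 64, 16) == 128              # was 8 + 64*15 = 968 before the clamp
assert padded_len(256, 64, 16) == 256 + 64 * 15  # large batches unchanged
```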