Upgrade to vllm 0.17.0 corex v4.1 overlay

This commit is contained in:
2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
FusedMoEKernel,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
MoEPrepareAndFinalizeNoDPEPModular,
)
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
@@ -83,7 +83,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
hidden_size = layer.hidden_size
intermediate_size = layer.intermediate_size_per_partition
intermediate_size = (
self.w2_lora_a_stacked[0].shape[-1]
if op_prefix == "w2"
else self.w13_lora_b_stacked[0].shape[-2]
)
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
max_loras=num_loras,
@@ -132,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
# Use the existing modular kernel from the quant method
m_fused_moe_fn = self.base_layer.quant_method.moe_mk
m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
# Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream.
m_fused_moe_fn.shared_experts = None
@@ -140,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
# Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can
# overlap them with routed experts via a separate CUDA stream.
prepare_finalize = MoEPrepareAndFinalizeNoEP()
m_fused_moe_fn = FusedMoEModularKernel(
prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
m_fused_moe_fn = FusedMoEKernel(
prepare_finalize,
self.base_layer.quant_method.select_gemm_impl(
prepare_finalize, self.base_layer
@@ -150,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if quant_config.use_mxfp4_w4a16:
assert isinstance(
m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
m_fused_moe_fn.impl.fused_experts,
(MarlinExperts, UnfusedOAITritonExperts),
)
else:
assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts)
assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
@@ -333,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
return wrapper
fused_experts = m_fused_moe_fn.fused_experts
fused_experts = m_fused_moe_fn.impl.fused_experts
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
)

View File

@@ -88,10 +88,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
# TODO: Verify if this condition can be further relaxed
if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048:
raise ValueError(
"When using LoRA, vocab size must be > 32000 and <= 258048"
)
if self.base_layer.vocab_size > 258048:
raise ValueError("When using LoRA, vocab size must be <= 258048")
self.lora_a_stacked = torch.zeros(
(
max_loras,

View File

@@ -8,9 +8,10 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.triton_utils import tl, triton
from vllm.triton_utils.allocation import set_triton_allocator
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
from .utils import supports_pdl, supports_tma
@triton.jit
@@ -70,6 +71,37 @@ def _get_token_offs(
)
@triton.jit
def _get_c_ptrs(
cur_c_ptr,
lora_id,
pid_m,
offs,
offs_token,
offs_cn,
stride_cm,
stride_cn,
EM: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
sort_c: tl.constexpr,
):
# When sort_c is true, store the output in c_ptr using token order defined
# in sorted_token_ids_ptr; otherwise, use the original token order from the prompt
if sort_c:
offs_token_id = pid_m * BLOCK_SIZE_M + offs
c_ptrs = (
cur_c_ptr
+ lora_id * EM * stride_cm
+ stride_cm * offs_token_id[:, None]
+ stride_cn * offs_cn[None, :]
)
else:
c_ptrs = (
cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
)
return c_ptrs
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -95,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
def _adjust_kernel_inputs(
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
@@ -109,7 +141,7 @@ def _adjust_kernel_inputs(
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
grid_lora_dim = num_active_loras
grid_lora_dim = num_active_loras.item()
return grid_lora_dim, stride_tl, stride_el
@@ -125,7 +157,9 @@ def _adjust_kernel_inputs(
)
def _fused_moe_lora_kernel(
a_ptr,
a_desc,
b_ptr,
b_desc,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
@@ -177,6 +211,18 @@ def _fused_moe_lora_kernel(
USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr,
IS_PRIMARY: tl.constexpr,
USE_TMA: tl.constexpr,
# sort_c determines whether tokens are stored in C in the order determined
# by sorted_token_ids to enable later TMA loads from this tensor.
#
# When USE_TMA is enabled, the parameter combinations are:
# a_desc | b_desc | sort_c | Use Case
# --------|---------|--------|-----------------------------
# yes | yes | False | expand kernel (num_slices=1)
# no | yes | True | shrink kernel (num_slices=1)
# yes | no | False | expand kernel (num_slices>1)
# no | no | True | shrink kernel (num_slices>1)
sort_c: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
@@ -250,58 +296,90 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
# remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs
a_ptrs = cur_a_ptr + (
offs_token[:, None] // token_mapping_factor * stride_am
+ offs_k[None, :] * stride_ak
)
if USE_TMA and a_desc is not None:
# Expand path - with TMA enabled, load from A using TMA descriptor
offs_am = (
slice_id * max_loras * EM
+ lora_id * EM
+ pid_m * BLOCK_SIZE_M // token_mapping_factor
)
offs_ak = pid_sk * BLOCK_SIZE_K
else:
# Shrink path - load hidden states based on order defined in
# 'sorted_token_ids_ptr' then store them in c_ptr in this same sorted order
tl.static_assert(a_desc is None, "a_desc must be none")
a_ptrs = cur_a_ptr + (
offs_token[:, None] // token_mapping_factor * stride_am
+ offs_k[None, :] * stride_ak
)
b_ptrs = (
cur_b_ptr
+ lora_id * stride_bl
+ expert_id * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if USE_TMA:
offs_bn = pid_n * BLOCK_SIZE_N
offs_bk = pid_sk * BLOCK_SIZE_K
if b_desc is None:
# Note(@gnovack) - Allocation of TMA descriptors on-device
# can cause conflicts when running in parallel via PDL
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
b_desc = tl.make_tensor_descriptor(
cur_b_ptr,
shape=[max_loras, num_experts, N, K],
strides=[stride_bl, stride_be, stride_bn, stride_bk],
block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K],
)
else:
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
b_ptrs = (
cur_b_ptr
+ lora_id * stride_bl
+ expert_id * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K)
k_remaining = K - cur_k_offset
# pre-fetch lora weight
# add (offs_bn < N) mask; optional .ca for B
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
if b_desc is not None:
b = (
b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset])
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
.T
)
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
# add (offs_bn < N) mask; optional .ca for B
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if a_desc is not None:
a = a_desc.load([offs_am, offs_ak + cur_k_offset])
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
@@ -309,7 +387,19 @@ def _fused_moe_lora_kernel(
accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_ptrs = _get_c_ptrs(
cur_c_ptr,
lora_id,
pid_m,
offs,
offs_token,
offs_cn,
stride_cm,
stride_cn,
EM,
BLOCK_SIZE_M,
sort_c,
)
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
if SPLIT_K == 1:
@@ -354,9 +444,10 @@ def _fused_moe_lora_shrink(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
use_gdc: bool = False,
use_tma: bool = False,
) -> None:
w1_lora_a_stacked = lora_a_stacked[0]
shrink_config = {
@@ -369,6 +460,7 @@ def _fused_moe_lora_shrink(
"SPLIT_K": split_k,
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
"USE_TMA": use_tma,
}
b_ptr = _get_ptr(lora_a_stacked, device)
@@ -383,9 +475,20 @@ def _fused_moe_lora_shrink(
len(lora_a_stacked),
grid_lora_dim,
)
a_desc = None
b_desc = None
if use_tma and num_slices == 1:
b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
lora_a_stacked[0],
[1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]],
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
a_desc,
b_ptr,
b_desc,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
@@ -407,8 +510,8 @@ def _fused_moe_lora_shrink(
w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
a_intermediate_cache1.stride(-2),
a_intermediate_cache1.stride(-1),
stride_tl,
stride_el,
slice_a_size=qcurr_hidden_states.numel(),
@@ -419,7 +522,8 @@ def _fused_moe_lora_shrink(
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=False,
ADD_INPUTS=False,
USE_B_L2_CACHE=True, # new
USE_B_L2_CACHE=True,
sort_c=use_tma and sorted_token_ids is not None,
IS_PRIMARY=True,
**shrink_config,
)
@@ -458,10 +562,11 @@ def _fused_moe_lora_expand(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
use_tma: bool = False,
) -> None:
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
@@ -470,7 +575,7 @@ def _fused_moe_lora_expand(
w1_lora_b_stacked = lora_b_stacked[0]
a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
-1, a_intermediate_cache1.shape[-1]
)
expand_config = {
@@ -483,6 +588,7 @@ def _fused_moe_lora_expand(
"SPLIT_K": 1, # Set split_k = 1 for expand calls
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
"USE_TMA": use_tma,
}
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
@@ -498,10 +604,27 @@ def _fused_moe_lora_expand(
# Fast path: directly accumulate into the corresponding slice interval of output.
out_view = output[:, :, offset : offset + num_slices * N]
slice_c_size = N * out_view.stride(2)
a_desc = None
b_desc = None
if use_tma:
if sorted_token_ids is not None:
a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
a_intermediate_cache1,
[expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]],
)
if num_slices == 1:
b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
lora_b_stacked[0],
[1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]],
)
else:
b_desc = None
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
a_desc,
b_ptr,
b_desc,
out_view,
topk_weights,
sorted_token_ids,
@@ -535,7 +658,8 @@ def _fused_moe_lora_expand(
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=mul_routed_weight,
ADD_INPUTS=True,
USE_B_L2_CACHE=True, # new
USE_B_L2_CACHE=True,
sort_c=False,
IS_PRIMARY=False,
**expand_config,
)
@@ -559,7 +683,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@@ -616,8 +740,34 @@ def _fused_moe_lora(
else num_tokens * shrink_block_size_m
)
# TMA is not currently compatiple with fully_sharded due to the non-determinism
# of token id sorting across ranks.
use_tma = supports_tma(device) and not fully_sharded
intermediate_cache_shape = (
num_slices,
M,
top_k_num,
max_lora_rank,
)
if use_tma:
if num_slices > 1:
# if num_slices > 1, we construct TMA descriptors for LoRA
# weights within the kernel, which requires us to first set an allocator
set_triton_allocator(device)
# When storing intermediate data in sorted order for TMA, we
# need an extra 'num_active_loras' dim in the cache to avoid conflicts
if sorted_token_ids is not None:
intermediate_cache_shape = (
num_slices,
sorted_token_ids.shape[0],
EM,
max_lora_rank,
)
a_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, max_lora_rank),
intermediate_cache_shape,
dtype=output.dtype,
device=device,
)
@@ -654,6 +804,7 @@ def _fused_moe_lora(
num_active_loras,
mul_routed_weight,
use_gdc=use_gdc,
use_tma=use_tma,
)
if fully_sharded:
@@ -703,6 +854,7 @@ def _fused_moe_lora(
mul_routed_weight,
offset,
use_gdc=use_gdc,
use_tma=use_tma,
)
@@ -719,7 +871,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@@ -769,9 +921,10 @@ def _fused_moe_lora_shrink_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
use_gdc: bool = False,
use_tma: bool = False,
) -> None:
return
@@ -805,10 +958,11 @@ def _fused_moe_lora_expand_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
use_tma: bool = False,
) -> None:
return

View File

@@ -138,7 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
@@ -235,7 +235,7 @@ def _lora_expand(
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
num_active_loras,
num_active_loras.item(),
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
@@ -289,7 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0,
add_inputs: bool = False,
) -> None:

View File

@@ -29,9 +29,16 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
# Stored as a CPU tensor (not a Python int) so that torch.compile treats
# it as a dynamic value rather than baking it as a constant at trace time.
# This follows the same pattern as no_lora_flag_cpu above.
num_active_loras_cpu: torch.Tensor
# Default num_active_loras value (max_loras + 1) as a CPU tensor,
# used when specialize_active_lora is False to avoid allocating a
# new tensor on every meta_args() call.
default_num_active_loras_cpu: torch.Tensor
# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
@@ -73,6 +80,11 @@ class LoRAKernelMeta:
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
default_num_active_loras_cpu = torch.tensor(
[max_loras + 1], dtype=torch.int32, device="cpu"
)
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
@@ -80,6 +92,8 @@ class LoRAKernelMeta:
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
num_active_loras_cpu=num_active_loras_cpu,
default_num_active_loras_cpu=default_num_active_loras_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
@@ -90,8 +104,7 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0
self.captured_lora_counts = []
self.num_active_loras_cpu.fill_(0)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
@@ -137,14 +150,16 @@ class LoRAKernelMeta:
num_tokens_per_lora, non_blocking=True
)
self.num_active_loras = lora_ids.size(0)
num_active_loras = lora_ids.size(0)
# Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph.
if self.captured_lora_counts and self.num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
if self.captured_lora_counts and num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx]
num_active_loras = self.captured_lora_counts[idx]
self.num_active_loras_cpu[0] = num_active_loras
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
@@ -163,7 +178,7 @@ class LoRAKernelMeta:
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
torch.Tensor,
]:
"""
This function returns the kernel metadata required for the current
@@ -175,7 +190,10 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward
pass of the kernel.
"""
max_loras = self.active_lora_ids.size(0) - 1
if specialize_active_lora:
num_active_loras = self.num_active_loras_cpu
else:
num_active_loras = self.default_num_active_loras_cpu
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
@@ -183,5 +201,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1,
num_active_loras,
)

View File

@@ -134,7 +134,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float,
) -> None:
"""
@@ -157,6 +157,9 @@ def _lora_shrink(
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
number of active LoRAs. Stored as a tensor (not int) so
torch.compile treats it as dynamic rather than a constant.
scaling (float): Scaling factor.
"""
@@ -215,7 +218,7 @@ def _lora_shrink(
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
num_active_loras,
num_active_loras.item(),
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
@@ -267,7 +270,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float,
) -> None:
return

View File

@@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool:
and current_platform.has_device_capability(90)
and not envs.VLLM_LORA_DISABLE_PDL
)
@lru_cache
def supports_tma(device: torch.device | None = None) -> bool:
# TMA requires compute capability SM90 or above
return current_platform.is_cuda() and current_platform.has_device_capability(90)

View File

@@ -233,6 +233,17 @@ class PunicaWrapperGPU(PunicaWrapperBase):
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
import vllm.envs as env
if env.VLLM_USE_LORA_FUSION:
import ixformer.inference.functions as ops
num_token, m = x.size(0), x.size(-1)
k, n = lora_b_stacked[0].size(-1), y.size(-1)
if len(lora_a_stacked) == 1 and ops.lora_gemv_optim_condition(num_token, m, k, n):
ops.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked,
lora_bias_stacked = None, scale = 1.0, output_slices = (1,))
return
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
@@ -351,6 +362,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
if topk_ids.numel() < num_experts:
max_num_tokens_padded = topk_ids.numel() * block_size
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,