Upgrade to vLLM 0.17.0 (Corex v4.1 overlay)
This commit is contained in:
@@ -2,7 +2,10 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import product as iprod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,13 +15,208 @@ from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import largest_power_of_2_divisor
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionMetadataBuilder,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
AttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
def _zero_kv_blocks_kernel(
    seg_addrs_ptr,  # int64 absolute byte address of each segment start
    block_ids_ptr,  # int64 logical block ids to zero
    n_blocks,  # number of valid entries in block_ids_ptr
    N_SEGS: tl.constexpr,  # number of segments in seg_addrs_ptr
    PAGE_SIZE_EL: tl.constexpr,  # int32 elements per logical block
    BLOCK_SIZE: tl.constexpr,  # int32 elements written per program
):
    """Zero KV cache blocks across all segments in a single launch.

    Each segment is a contiguous region of one block's data. For backends
    where blocks are outermost (block_dim=0) there is one segment per
    buffer. For backends where K/V is outermost (block_dim=1) there are
    two segments per buffer (one for K, one for V).

    seg_addrs_ptr holds absolute byte addresses (int64) for each segment,
    allowing segments to live in different CUDA allocations.

    Programs are mapped as (block_index, seg_index, chunk_index).
    """
    pid = tl.program_id(0)
    # Number of BLOCK_SIZE-sized chunks needed to cover one page.
    # Caller chooses BLOCK_SIZE as a divisor of PAGE_SIZE_EL.
    chunks = PAGE_SIZE_EL // BLOCK_SIZE
    work_per_block = N_SEGS * chunks
    # Decompose the flat program id into (block, segment, chunk).
    block_index = pid // work_per_block
    if block_index >= n_blocks:
        return
    remainder = pid % work_per_block
    seg_index = remainder // chunks
    chunk_index = remainder % chunks
    block_id = tl.load(block_ids_ptr + block_index)
    seg_addr = tl.load(seg_addrs_ptr + seg_index)
    # Reinterpret the raw byte address as an int32 pointer; all offsets
    # below are therefore measured in 4-byte elements.
    ptr = tl.cast(seg_addr, tl.pointer_type(tl.int32))
    # Promote to int64 before multiplying so large caches don't overflow
    # 32-bit offset arithmetic.
    offset = (
        block_id.to(tl.int64) * PAGE_SIZE_EL + chunk_index.to(tl.int64) * BLOCK_SIZE
    )
    cols = tl.arange(0, BLOCK_SIZE).to(tl.int64)
    tl.store(ptr + offset + cols, tl.zeros([BLOCK_SIZE], dtype=tl.int32))
|
||||
|
||||
|
||||
class KVBlockZeroer:
    """Manages efficient zeroing of KV cache blocks via a Triton kernel.

    Call :meth:`init_meta` once after KV caches are allocated to precompute
    segment addresses, then call :meth:`zero_block_ids` each step to zero
    newly-allocated blocks.
    """

    def __init__(self, device: torch.device, pin_memory: bool):
        self.device = device
        self.pin_memory = pin_memory
        # (seg_addrs tensor, PAGE_SIZE_EL, BLOCK_SIZE, n_segs), or None
        # until init_meta succeeds (zero_block_ids is a no-op while None).
        self._meta: tuple[torch.Tensor, int, int, int] | None = None
        # Capacity of the block-id staging buffers below.
        self._id_cap: int = 0
        # CPU staging buffer (optionally pinned) and its GPU mirror, used
        # for the async host-to-device copy of block ids each step.
        self._ids_pinned: torch.Tensor | None = None
        self._ids_gpu: torch.Tensor | None = None

    def init_meta(
        self,
        attn_groups_iter: Iterable["AttentionGroup"],
        kernel_block_sizes: list[int],
        cache_dtype: str,
        runner_only_attn_layers: set[str],
        static_forward_context: dict[str, Any],
    ) -> None:
        """One-time precomputation for zero_block_ids.

        Builds absolute-address table for the Triton zeroing kernel.
        Each entry is the absolute byte address of a segment start on the
        GPU, so segments in different CUDA allocations work correctly.

        Block IDs from the scheduler reference logical blocks whose size
        may differ from the kernel block size (virtual block splitting).
        PAGE_SIZE_EL accounts for this ratio so that
        ``block_id * PAGE_SIZE_EL`` lands at the correct offset.

        Only AttentionSpec layers are processed; Mamba layers are skipped.
        """
        seen_ptrs: set[int] = set()
        seg_addrs: list[int] = []
        page_size_el: int | None = None

        for group in attn_groups_iter:
            spec = group.kv_cache_spec
            # Exact type check: subclasses of FullAttentionSpec (and all
            # other spec types) are deliberately excluded.
            if type(spec) is not FullAttentionSpec:
                continue
            if group.kv_cache_group_id >= len(kernel_block_sizes):
                continue
            kernel_bs = kernel_block_sizes[group.kv_cache_group_id]
            # Logical (manager) blocks per kernel block; scales the kernel
            # page size back up to one logical block.
            ratio = spec.block_size // kernel_bs
            block_dim = group.backend.get_kv_cache_block_dim(
                kernel_bs,
                spec.num_kv_heads,
                spec.head_size,
                cache_dtype_str=cache_dtype,
            )

            for layer_name in group.layer_names:
                if layer_name in runner_only_attn_layers:
                    continue
                kv = static_forward_context[layer_name].kv_cache[0]
                # A list-valued cache entry is not a plain tensor we can
                # address; skip it.
                if isinstance(kv, list):
                    continue
                dp = kv.data_ptr()
                # Layers may share one allocation; record each buffer once.
                if dp in seen_ptrs:
                    continue
                seen_ptrs.add(dp)

                el = kv.element_size()
                # Bytes spanned by one kernel block along block_dim.
                cur_bytes = kv.stride(block_dim) * el
                # The zeroing kernel writes int32 words, so the block byte
                # size must be 4-byte divisible.
                assert cur_bytes % 4 == 0
                kernel_block_el = cur_bytes // 4
                cur_page_el = kernel_block_el * ratio
                # All buffers must agree on the per-logical-block element
                # count, since the kernel uses a single PAGE_SIZE_EL.
                if page_size_el is None:
                    page_size_el = cur_page_el
                else:
                    assert page_size_el == cur_page_el, (
                        f"Non-uniform page sizes: {page_size_el} vs {cur_page_el}"
                    )

                block_stride_bytes = cur_bytes
                # Dimensions outside block_dim with a larger stride split
                # the buffer into multiple contiguous segments (e.g. a
                # leading K/V dimension); enumerate one segment start per
                # index combination of those outer dims.
                outer_dims = [
                    d
                    for d in range(block_dim)
                    if kv.stride(d) * el > block_stride_bytes
                ]
                outer_strides = [kv.stride(d) * el for d in outer_dims]
                for outer in iprod(*(range(kv.shape[d]) for d in outer_dims)):
                    off_bytes = sum(i * s for i, s in zip(outer, outer_strides))
                    seg_addrs.append(dp + off_bytes)

        if not seg_addrs or page_size_el is None:
            # Nothing to zero (e.g. no full-attention layers); disable.
            self._meta = None
            return

        # Largest power-of-two chunk that evenly divides the page, capped
        # at 1024 elements per program.
        blk_size = min(largest_power_of_2_divisor(page_size_el), 1024)
        # Initial staging capacity; grown on demand in zero_block_ids.
        self._id_cap = 8192
        self._ids_pinned = torch.empty(
            self._id_cap,
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        self._ids_gpu = torch.empty(self._id_cap, dtype=torch.int64, device=self.device)
        self._meta = (
            torch.tensor(seg_addrs, dtype=torch.int64, device=self.device),
            page_size_el,
            blk_size,
            len(seg_addrs),
        )

    def zero_block_ids(self, block_ids: list[int]) -> None:
        """Zero the KV cache memory for the given block IDs."""
        if not block_ids or self._meta is None:
            return
        seg_addrs, page_size_el, blk_size, n_segs = self._meta
        n_blocks = len(block_ids)
        if n_blocks > self._id_cap:
            # Grow geometrically to amortize reallocation cost.
            self._id_cap = n_blocks * 2
            self._ids_pinned = torch.empty(
                self._id_cap,
                dtype=torch.int64,
                pin_memory=self.pin_memory,
            )
            self._ids_gpu = torch.empty(
                self._id_cap, dtype=torch.int64, device=self.device
            )
        assert self._ids_pinned is not None and self._ids_gpu is not None
        # Write ids into the staging buffer in place via its numpy view
        # (no intermediate tensor allocation), then copy asynchronously
        # to the GPU mirror.
        self._ids_pinned[:n_blocks].numpy()[:] = block_ids
        idx = self._ids_gpu[:n_blocks]
        idx.copy_(self._ids_pinned[:n_blocks], non_blocking=True)
        # One program per (block, segment, chunk) — matches the kernel's
        # flat pid decomposition.
        grid = (n_blocks * n_segs * (page_size_el // blk_size),)
        _zero_kv_blocks_kernel[grid](
            seg_addrs,
            idx,
            n_blocks,
            N_SEGS=n_segs,
            PAGE_SIZE_EL=page_size_el,
            BLOCK_SIZE=blk_size,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
@@ -36,7 +234,7 @@ class AttentionGroup:
|
||||
self,
|
||||
vllm_config,
|
||||
device,
|
||||
kernel_block_size: int | None,
|
||||
kernel_block_size: int | None = None,
|
||||
num_metadata_builders: int = 1,
|
||||
):
|
||||
kv_cache_spec_builder = (
|
||||
@@ -59,6 +257,119 @@ class AttentionGroup:
|
||||
return self.metadata_builders[ubatch_id]
|
||||
|
||||
|
||||
def select_common_block_size(
    kv_manager_block_size: int, attn_groups: list[AttentionGroup]
) -> int:
    """
    Select a block size that is supported by all backends and is a factor of
    kv_manager_block_size.

    If kv_manager_block_size is supported by all backends, return it directly.
    Otherwise, return the max supported size.

    Args:
        kv_manager_block_size: Block size of KV cache.
        attn_groups: List of attention groups.

    Returns:
        The selected block size.

    Raises:
        ValueError: If no valid block size found.
    """

    def _backend_accepts(backend: type[AttentionBackend], candidate: int) -> bool:
        """Whether one backend accepts ``candidate`` as a kernel block size.

        Scans every advertised entry (rather than short-circuiting) so that
        a malformed entry always raises, matching strict validation.
        """
        accepted = False
        for entry in backend.get_supported_kernel_block_sizes():
            if isinstance(entry, int):
                if candidate == entry:
                    accepted = True
            elif isinstance(entry, MultipleOf):
                if candidate % entry.base == 0:
                    accepted = True
            else:
                raise ValueError(f"Unknown supported size: {entry}")
        return accepted

    def _all_backends_accept(candidate: int) -> bool:
        """Whether every backend accepts ``candidate``."""
        for backend in backends:
            if not _backend_accepts(backend, candidate):
                return False
        return True

    backends = [group.backend for group in attn_groups]

    # Case 1: if the block_size of kv cache manager is supported by all backends,
    # return it directly.
    if _all_backends_accept(kv_manager_block_size):
        return kv_manager_block_size

    # Case 2: otherwise, the block_size must be an `int`-format supported size of
    # at least one backend. Iterate over all `int`-format supported sizes in
    # descending order and return the first one that is supported by all backends.
    # Simple proof:
    # If the supported size b is in MultipleOf(x_i) format for all attention
    # backends i, and b a factor of kv_manager_block_size, then
    # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We will
    # return kv_manager_block_size in case 1.
    explicit_sizes = {
        entry
        for backend in backends
        for entry in backend.get_supported_kernel_block_sizes()
        if isinstance(entry, int)
    }

    for candidate in sorted(explicit_sizes, reverse=True):
        # Must evenly divide the manager block size and satisfy every backend.
        if kv_manager_block_size % candidate == 0 and _all_backends_accept(candidate):
            return candidate
    raise ValueError(f"No common block size for {kv_manager_block_size}. ")
|
||||
|
||||
|
||||
def prepare_kernel_block_sizes(
    kv_cache_config: KVCacheConfig, attn_groups: list[list[AttentionGroup]]
) -> list[int]:
    """
    Generate kernel_block_sizes that matches each block_size.

    For attention backends that support virtual block splitting,
    use the supported block sizes from the backend.
    For other backends (like Mamba), use the same block size (no splitting).

    Args:
        kv_cache_config: The KV cache configuration.
        attn_groups: Attention groups indexed by KV cache group id.

    Returns:
        List of kernel block sizes for each cache group.
    """
    sizes: list[int] = []
    for gid, cache_group in enumerate(kv_cache_config.kv_cache_groups):
        spec = cache_group.kv_cache_spec
        if isinstance(spec, UniformTypeKVCacheSpecs):
            # All layers in a uniform-type group share one spec type; any
            # member is representative for dispatching below.
            spec = next(iter(spec.kv_cache_specs.values()))
        if isinstance(spec, EncoderOnlyAttentionSpec):
            # Encoder-only attention contributes no entry to the result.
            continue
        if isinstance(spec, AttentionSpec):
            # Attention backends support virtual block splitting; pick a
            # kernel size common to every backend in this group.
            manager_block_size = cache_group.kv_cache_spec.block_size
            sizes.append(
                select_common_block_size(manager_block_size, attn_groups[gid])
            )
        elif isinstance(spec, MambaSpec):
            # Mamba (and similar non-attention caches): no splitting.
            sizes.append(spec.block_size)
        else:
            raise NotImplementedError(
                f"unknown kv cache spec {cache_group.kv_cache_spec}"
            )
    return sizes
|
||||
|
||||
|
||||
def sanity_check_mm_encoder_outputs(
|
||||
mm_embeddings: MultiModalEmbeddings,
|
||||
expected_num_items: int,
|
||||
@@ -201,6 +512,55 @@ def bind_kv_cache(
|
||||
for layer_name, kv_cache in kv_caches.items():
|
||||
# NOTE: Use list because of v0 PP virtual engine.
|
||||
forward_context[layer_name].kv_cache = [kv_cache]
|
||||
|
||||
def bind_kv_cache_scale(
    kv_caches_scale: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches_scale: list[torch.Tensor],
    num_attn_module: int | None = 1,
) -> None:
    """
    Bind the allocated KV cache scales to both ModelRunner and forward
    context so that they can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's scale list (`runner_kv_caches_scale`)
         with the tensors from `kv_caches_scale`, ordered by layer index.
      2) Associates each attention layer in the `forward_context` with its
         corresponding scale tensor in `kv_caches_scale`.

    Args:
        kv_caches_scale: The allocated KV-cache scale tensors with layer
            names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches_scale: The scale list declared by ModelRunner;
            must be empty on entry and is filled in place.
        num_attn_module: Forwarded to `extract_layer_index` when mapping a
            layer name to its index (presumably the number of attention
            modules per layer — confirm against extract_layer_index).
    """
    # Bind scale tensors to ModelRunner (list must start empty).
    assert len(runner_kv_caches_scale) == 0

    # Convert the dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches_scale:
        index2name[extract_layer_index(layer_name, num_attn_module)].append(
            layer_name
        )

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            if current_platform.is_cuda() or current_platform.is_xpu():
                pass
            else:
                raise NotImplementedError
        # Only the first layer sharing this index is appended to the
        # runner list; the per-layer binding below still covers all.
        layer_name = layer_names[0]
        runner_kv_caches_scale.append(kv_caches_scale[layer_name])

    # Bind scale tensors to forward context
    for layer_name, kv_cache_scale in kv_caches_scale.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache_scale = [kv_cache_scale]
|
||||
|
||||
|
||||
def is_residual_scattered_for_sp(
|
||||
|
||||
Reference in New Issue
Block a user