# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

from vllm.utils.math_utils import cdiv


def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
    # new_kv_start, slice_len)
    num_slices_ref,  # [1]
    # Input
    new_kv_hbm_ref,  # [num_tokens, num_combined_kv_heads, head_dim]
    kv_cache_hbm_ref,  # [total_num_pages * page_size, num_combined_kv_heads,
    # head_dim]
    # Output
    _,  # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    # Scratch
    scratch,  # [num_slices_per_block, page_size, num_combined_kv_heads,
    # head_dim]
    sem,
):
    """Copy a block of token slices from ``new_kv`` into the paged KV cache.

    Each grid step handles up to ``num_slices_per_block`` slices (taken from
    ``scratch.shape[0]``). Data is staged HBM -> VMEM scratch -> HBM with
    async DMA copies: all copies of a phase are started first, then waited
    on, so the per-slice transfers within a phase can overlap.

    Slices past the real count (``num_slices_ref[0]``) are masked by
    selecting start/length of 0 via ``jax.lax.select``, which turns them
    into empty (no-op) copies.
    """
    async_copies = []
    block_idx = pl.program_id(0)
    num_slices_per_block = scratch.shape[0]

    # Phase 1: copy from new_kv_hbm_ref into the VMEM scratch buffer.
    for i in range(num_slices_per_block):
        # Global slice index for this (grid step, local slot) pair.
        offset_i = i + block_idx * num_slices_per_block
        # Mask out-of-range slices to (start=0, length=0) -> empty copy.
        new_kv_start = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[1, offset_i], 0
        )
        length = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
        )
        async_copy = pltpu.make_async_copy(
            new_kv_hbm_ref.at[pl.ds(new_kv_start, length), ...],
            scratch.at[i, pl.ds(0, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    # Start all copies before waiting so the DMAs can run concurrently.
    for async_copy in async_copies:
        async_copy.wait()

    # Phase 2: copy from scratch into kv_cache_hbm_ref at each slice's
    # destination offset (slices_ref row 0 = kv_cache_start).
    async_copies.clear()
    for i in range(num_slices_per_block):
        offset_i = i + block_idx * num_slices_per_block
        kv_cache_start = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[0, offset_i], 0
        )
        length = jax.lax.select(
            offset_i < num_slices_ref[0], slices_ref[2, offset_i], 0
        )
        async_copy = pltpu.make_async_copy(
            scratch.at[i, pl.ds(0, length), ...],
            kv_cache_hbm_ref.at[pl.ds(kv_cache_start, length), ...],
            sem,
        )
        async_copy.start()
        async_copies.append(async_copy)
    for async_copy in async_copies:
        async_copy.wait()


@functools.partial(
    jax.jit,
    static_argnames=["page_size", "num_slices_per_block"],
)
def kv_cache_update(
    # [total_num_token, num_combined_kv_heads, head_dim]
    new_kv: jax.Array,
    # [3, slices], list of (kv_cache_start, new_kv_start, slice_len)
    slices: jax.Array,
    # [total_num_pages * page_size, num_combined_kv_heads, head_dim]
    kv_cache: jax.Array,
    # [1]
    num_kv_update_slices: jax.Array,
    *,
    page_size: int = 32,
    num_slices_per_block: int = 8,
) -> jax.Array:
    """Scatter new KV entries into the paged KV cache via a Pallas kernel.

    Args:
        new_kv: Tokens to insert, shape
            ``[total_num_token, num_combined_kv_heads, head_dim]``.
        slices: ``[3, padded_num_slices]`` rows of
            ``(kv_cache_start, new_kv_start, slice_len)`` describing each
            contiguous copy; passed to the kernel as a scalar prefetch.
        kv_cache: Destination cache, flattened over pages to
            ``[total_num_pages * page_size, num_combined_kv_heads, head_dim]``.
        num_kv_update_slices: ``[1]`` array holding the number of valid
            entries in ``slices``; entries beyond it are ignored.
        page_size: Static; sizes the per-slice scratch staging area
            (each slice is assumed to fit within one page — see TODO below).
        num_slices_per_block: Static; slices processed per grid step.

    Returns:
        The updated ``kv_cache``. The output buffer is aliased to the
        ``kv_cache`` input, so the update is effectively in place.
    """
    _, num_combined_kv_heads, head_dim = new_kv.shape
    assert kv_cache.shape[1] == num_combined_kv_heads
    assert kv_cache.shape[2] == head_dim
    # NOTE(review): presumably required for TPU lane alignment of the last
    # (128-wide) dimension — confirm against the target hardware docs.
    assert head_dim % 128 == 0
    # TODO: Add dynamic check to make sure that all the slice lengths are
    # smaller or equal to page_size

    # Both tensor inputs stay in HBM (memory_space=ANY); the kernel DMAs
    # the needed slices into VMEM scratch itself.
    in_specs = [
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
    ]

    out_specs = [pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)]
    out_shape = [jax.ShapeDtypeStruct(kv_cache.shape, dtype=kv_cache.dtype)]

    scalar_prefetches = [slices, num_kv_update_slices]
    # VMEM staging buffer: one [page_size, heads, head_dim] slot per slice
    # handled in a grid step.
    scratch = pltpu.VMEM(
        (num_slices_per_block, page_size, num_combined_kv_heads, head_dim),
        new_kv.dtype,
    )

    scratch_shapes = [
        scratch,
        pltpu.SemaphoreType.DMA,
    ]

    kernel = pl.pallas_call(
        _kv_cache_update_kernel,
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=len(scalar_prefetches),
            in_specs=in_specs,
            out_specs=out_specs,
            # One grid step per group of num_slices_per_block slices. The
            # grid size is data-dependent (num_kv_update_slices is traced),
            # relying on Pallas support for dynamic grid dimensions.
            grid=(cdiv(num_kv_update_slices[0], num_slices_per_block),),
            scratch_shapes=scratch_shapes,
        ),
        out_shape=out_shape,
        # Alias input kv_cache (positional index len(scalar_prefetches) + 1,
        # i.e. the 4th argument below) to output 0 -> in-place update.
        input_output_aliases={len(scalar_prefetches) + 1: 0},
    )

    return kernel(*scalar_prefetches, new_kv, kv_cache)[0]