# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any

import torch

from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch


# Batched byte copy: program `pid` copies `sizes[pid]` bytes from the address
# stored in `src_ptrs[pid]` to the address stored in `dst_ptrs[pid]`.
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)

    # Per-program copy descriptor.
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)

    offsets = tl.arange(0, BLOCK_SIZE)
    # Walk the range in BLOCK_SIZE chunks; the mask trims the final partial chunk.
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size

        # Addresses are reinterpreted as uint8 pointers, so `size` is in bytes.
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))

        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    """Launch ``batch_memcpy_kernel`` with one program per copy descriptor.

    Args:
        src_ptrs: Tensor of source addresses, one per copy.
        dst_ptrs: Tensor of destination addresses, one per copy.
        sizes: Tensor of copy sizes in bytes, one per copy.

    All three tensors must have the same first dimension. The only caller in
    this file passes int64 addresses and int32 sizes.
    """
    batch = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == batch
    assert sizes.shape[0] == batch
    if batch == 0:
        # Nothing to copy, so skip the kernel launch.
        return

    grid = (batch,)
    BLOCK_SIZE = 1024
    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)


def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    """Find the KV cache groups whose spec is a ``MambaSpec``.

    Args:
        kv_cache_config: Config whose ``kv_cache_groups`` are scanned.

    Returns:
        A tuple ``(mamba_group_ids, mamba_spec)``: the indices of the mamba
        groups within ``kv_cache_config.kv_cache_groups`` and the spec they
        all share.

    Raises:
        AssertionError: If there is no mamba group, or if the mamba groups do
            not all compare equal to the first one.
    """
    mamba_group_ids: list[int] = []
    mamba_specs: list[MambaSpec] = []
    for group_id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
        kv_cache_spec = kv_cache_group.kv_cache_spec
        if isinstance(kv_cache_spec, MambaSpec):
            mamba_group_ids.append(group_id)
            mamba_specs.append(kv_cache_spec)
    assert len(mamba_group_ids) > 0, "no mamba layers in the model"
    assert all(mamba_specs[0] == spec for spec in mamba_specs)
    return mamba_group_ids, mamba_specs[0]


def collect_mamba_copy_meta(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
):
    """Queue the copies needed to move one request's mamba state between blocks.

    Nothing is copied here; the three output lists are extended in lockstep
    and are later consumed by ``do_mamba_copy_block``.

    Args:
        src_state_list: Output list; receives source addresses.
        dest_state_list: Output list; receives destination addresses.
        num_elements_list: Output list; receives copy sizes in bytes
            (``copy_spec.num_elements * state.element_size()``).
        kv_cache_config: Supplies the layer names of every mamba group.
        mamba_state_copy_funcs: One function per state tensor of a layer,
            zipped with that layer's state tensors.
        mamba_group_ids: Indices of the mamba KV cache groups.
        src_block_idx: Index into the request's block list of the source block.
        dest_block_idx: Index into the request's block list of the target block.
        accept_token_bias: Offset forwarded (plus one) to the copy functions.
        req_state: Request state providing ``block_ids`` per group.
        forward_context: Maps a layer name to the layer object holding
            ``kv_cache``.
    """
    # Same block and no offset: the state is already where it should be.
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    for mamba_group_id in mamba_group_ids:
        block_ids = req_state.block_ids[mamba_group_id]
        dest_block_id = block_ids[dest_block_idx]
        layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
        for layer_name in layer_names:
            attention = forward_context[layer_name]
            kv_caches: list[torch.Tensor] = attention.kv_cache[0]
            for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
                # NOTE(review): the copy function presumably resolves the source
                # address/length for `src_block_idx` given the number of accepted
                # tokens -- confirm against MambaStateCopyFunc.
                copy_spec = state_copy_func(
                    state, block_ids, src_block_idx, accept_token_bias + 1
                )

                src_state_list.append(copy_spec.start_addr)
                dest_state_list.append(state[dest_block_id].data_ptr())
                num_elements_list.append(copy_spec.num_elements * state.element_size())


def do_mamba_copy_block(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
):
    """Execute the copies queued by ``collect_mamba_copy_meta`` in one launch.

    Args:
        src_state_list: Source addresses.
        dest_state_list: Destination addresses.
        num_elements_list: Copy sizes. Despite the name these are byte counts,
            as that is what ``collect_mamba_copy_meta`` appends.

    The lists are materialized as tensors on ``"cuda"`` (int64 addresses,
    int32 sizes) and handed to ``batch_memcpy``. An empty queue is a no-op.
    """
    if not src_state_list:
        return
    assert len(src_state_list) == len(dest_state_list)
    assert len(src_state_list) == len(num_elements_list)

    src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
    dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
    num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)

    batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)


def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
    """
    Copy the mamba state of previous step to the last
    (1 + num_speculative_blocks) block.

    Side effects:
        * Entries of ``mamba_state_idx`` for finished, preempted and resumed
          requests are removed.
        * ``mamba_state_idx[req_id]`` is set to the block index that holds the
          running state for this step, for every request in ``input_batch``.
        * ``input_batch.num_accepted_tokens_cpu[i]`` is reset to 1 for every
          request for which a copy was queued.

    Raises:
        AssertionError: If ``cache_config.enable_prefix_caching`` is falsy.
    """
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size
    finished_req_ids = scheduler_output.finished_req_ids
    preempted_req_ids = scheduler_output.preempted_req_ids or set()
    # We need to clear mamba_state_idx for resumed requests. When requests are
    # force-preempted (e.g., during reset_prefix_cache / KV cache flush),
    # they appear in resumed_req_ids without a corresponding entry in
    # preempted_req_ids, leaving stale mamba_state_idx entries that can
    # point to block indices beyond the new (smaller) block allocation.
    resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
    for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
        mamba_state_idx.pop(req_id, None)

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # new / resumed request, no previous state
            # if num_computed_tokens is 0, prev_state_idx will be -1
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size

        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
        num_blocks: int = (
            cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
            + num_speculative_blocks
        )

        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
        # A corner case worth mentioning here: assume we have block_size = 4 and
        # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
        # tokens [draft 1, draft 2]. Then we will have:
        # Block 0: [A, B, C, draft 1]
        # Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        # Block 2: speculative block
        # Block 3: speculative block
        # And use block 1 to save the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                prev_state_idx,
                curr_state_idx,
                input_batch.num_accepted_tokens_cpu[i] - 1,
                req_state,
                forward_context,
            )
            # NOTE(review): reset only when a copy was queued; the queued copy
            # is assumed to already account for the accepted-token offset --
            # confirm.
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)


def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
    """
    If a block is converted from partial block to full block in this step, copy
    the state from the block for running state to the new full block.

    Side effects:
        ``input_batch.num_accepted_tokens_cpu[i]`` is reset to 1 for requests
        whose running-state block is also the destination block.
    """
    num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
    scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized as this function always returns the same result
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    block_size = mamba_spec.block_size

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed_tokens = req_state.num_computed_tokens
        num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
        num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
        num_accepted_tokens = num_accepted_tokens_cpu[i]

        # Token count covered by the running state, i.e. excluding draft tokens.
        num_tokens_running_state = (
            num_computed_tokens + num_scheduled_tokens - num_draft_tokens
        )
        new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
        # Round down to a multiple of the block size.
        aligned_new_computed_tokens = new_num_computed_tokens // block_size * block_size

        # TODO: how to ensure all blocks that cache_blocks called are cached here?
        if aligned_new_computed_tokens >= num_tokens_running_state:
            accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
            src_block_idx = mamba_state_idx[req_id]
            dest_block_idx = aligned_new_computed_tokens // block_size - 1
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                src_block_idx,
                dest_block_idx,
                accept_token_bias,
                req_state,
                forward_context,
            )
            if src_block_idx == dest_block_idx:
                num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)