Files
enginex-bi_150-vllm/vllm/v1/worker/mamba_utils.py
2026-04-09 11:23:47 +08:00

243 lines
9.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any
import torch
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateCopyFunc,
)
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    # Batched byte-wise memcpy: program instance `pid` handles entry `pid` of
    # the three parallel arrays. `src_ptrs[pid]` / `dst_ptrs[pid]` hold raw
    # device addresses (built as int64 tensors in do_mamba_copy_block) and
    # `sizes[pid]` is the number of bytes to copy for that entry.
    pid = tl.program_id(0)
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)
    offsets = tl.arange(0, BLOCK_SIZE)
    # Walk the range in BLOCK_SIZE-byte chunks; the mask trims the final,
    # possibly partial, chunk so nothing at or past `size` is touched.
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size
        # The loaded addresses are plain integers; reinterpret them as uint8
        # pointers so that offsets (and `size`) are counted in bytes.
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)
def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    """Launch ``batch_memcpy_kernel`` with one program per copy entry.

    ``src_ptrs``, ``dst_ptrs`` and ``sizes`` are parallel device tensors
    indexed along their first dimension: entry ``k`` copies ``sizes[k]``
    bytes from address ``src_ptrs[k]`` to address ``dst_ptrs[k]``. Each
    program moves its range in 1024-byte chunks.
    """
    num_copies = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == num_copies
    assert sizes.shape[0] == num_copies
    launch_grid = (num_copies,)
    batch_memcpy_kernel[launch_grid](
        src_ptrs,
        dst_ptrs,
        sizes,
        BLOCK_SIZE=1024,
    )
def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    """Find the KV-cache groups whose spec is a ``MambaSpec``.

    Returns:
        A pair ``(group_ids, spec)``: the indices into
        ``kv_cache_config.kv_cache_groups`` of every Mamba group, and the
        spec of the first such group (all of them must compare equal).

    Raises:
        AssertionError: if no group is a Mamba group, or if the Mamba groups
            do not all share an equal spec.
    """
    all_specs = [group.kv_cache_spec for group in kv_cache_config.kv_cache_groups]
    mamba_group_ids: list[int] = [
        group_id
        for group_id, spec in enumerate(all_specs)
        if isinstance(spec, MambaSpec)
    ]
    assert mamba_group_ids, "no mamba layers in the model"
    reference_spec = all_specs[mamba_group_ids[0]]
    assert all(reference_spec == all_specs[gid] for gid in mamba_group_ids)
    return mamba_group_ids, reference_spec
def collect_mamba_copy_meta(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
):
    """Queue the copies that move a request's Mamba state between blocks.

    For every layer of every Mamba group, each state tensor of the layer is
    paired (via ``zip``) with an entry of ``mamba_state_copy_funcs``. One
    ``(source address, destination address, byte count)`` triple is appended
    to the three output lists per pair:

    * the source address and element count come from the copy func, called
      with ``src_block_idx`` and ``accept_token_bias + 1`` (the exact meaning
      of that last argument is defined by ``MambaStateCopyFunc``);
    * the destination is the start of block ``dest_block_idx`` in the
      request's block table for that group;
    * the size appended to ``num_elements_list`` is in bytes (element count
      times element size), despite the list's name.

    Nothing is copied here. Nothing is queued at all when source and
    destination are the same block and ``accept_token_bias`` is 0.
    """
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return
    copy_func_token_arg = accept_token_bias + 1
    for group_id in mamba_group_ids:
        group_block_ids = req_state.block_ids[group_id]
        target_block_id = group_block_ids[dest_block_idx]
        group_layers = kv_cache_config.kv_cache_groups[group_id].layer_names
        for layer_name in group_layers:
            layer = forward_context[layer_name]
            layer_states: list[torch.Tensor] = layer.kv_cache[0]
            for state, copy_func in zip(layer_states, mamba_state_copy_funcs):
                copy_spec = copy_func(
                    state, group_block_ids, src_block_idx, copy_func_token_arg
                )
                num_bytes = copy_spec.num_elements * state.element_size()
                src_state_list.append(copy_spec.start_addr)
                dest_state_list.append(state[target_block_id].data_ptr())
                num_elements_list.append(num_bytes)
def do_mamba_copy_block(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
):
    """Execute the copies queued by ``collect_mamba_copy_meta``.

    The three lists are parallel: entry ``k`` copies ``num_elements_list[k]``
    bytes from address ``src_state_list[k]`` to ``dest_state_list[k]``. The
    lists are uploaded to the current CUDA device (addresses as int64, sizes
    as int32) and handed to ``batch_memcpy`` in a single call. Empty input is
    a no-op.
    """
    if not src_state_list:
        return
    assert len(src_state_list) == len(dest_state_list)
    assert len(src_state_list) == len(num_elements_list)
    src_ptrs, dst_ptrs = (
        torch.tensor(addresses, device="cuda", dtype=torch.int64)
        for addresses in (src_state_list, dest_state_list)
    )
    byte_counts = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
    batch_memcpy(src_ptrs, dst_ptrs, byte_counts)
def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
    """
    Copy the mamba state of previous step to the last
    (1 + num_speculative_blocks) block.

    For each request in ``input_batch``, the running state for this step
    lives in block ``num_blocks - 1 - num_speculative_blocks``. Here
    ``num_blocks`` is the number of blocks needed for all computed plus
    scheduled tokens, plus ``num_speculative_blocks``. If the state was kept
    in a different block in the previous step, a copy is queued, offset by
    ``num_accepted_tokens_cpu[i] - 1``. All queued copies are issued together
    at the end.

    Side effects:
        * ``mamba_state_idx`` entries are dropped for finished, preempted and
          resumed requests. They are then set to the new running-state block
          index for every request in the batch.
        * ``input_batch.num_accepted_tokens_cpu[i]`` is reset to 1 only for
          requests whose state was relocated by a queued copy.
    """
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size
    finished_req_ids = scheduler_output.finished_req_ids
    preempted_req_ids = scheduler_output.preempted_req_ids or set()
    # We need to clear mamba_state_idx for resumed requests. When requests are
    # force-preempted (e.g., during reset_prefix_cache / KV cache flush),
    # they appear in resumed_req_ids without a corresponding entry in
    # preempted_req_ids, leaving stale mamba_state_idx entries that can
    # point to block indices beyond the new (smaller) block allocation.
    resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
    stale_req_ids = itertools.chain(
        finished_req_ids, preempted_req_ids, resumed_req_ids
    )
    for req_id in stale_req_ids:
        mamba_state_idx.pop(req_id, None)

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # new / resumed request, no previous state
            # if num_computed_tokens is 0, prev_state_idx will be -1
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size
        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
        num_blocks: int = (
            cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
            + num_speculative_blocks
        )
        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
        # A corner case worth mentioning here: assume we have block_size = 4
        # and num_speculative_tokens = 2. The request is [A, B, C] and
        # contains 2 draft tokens [draft 1, draft 2]. Then we will have:
        # Block 0: [A, B, C, draft 1]
        # Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        # Block 2: speculative block
        # Block 3: speculative block
        # And use block 1 to save the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                prev_state_idx,
                curr_state_idx,
                input_batch.num_accepted_tokens_cpu[i] - 1,
                req_state,
                forward_context,
            )
            # The queued copy already consumed the accepted-token count as its
            # offset, so reset it only here. This mirrors postprocess_mamba,
            # which likewise resets the count only together with a copy. When
            # the state stays in the same block, no copy is queued and the
            # count is left untouched.
            # NOTE(review): assumes later readers of num_accepted_tokens_cpu
            # need the un-reset value in that case; confirm.
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
    """
    If a block is converted from a partial block to a full block in this step,
    copy the state from the block holding the running state to the new full
    block.

    For request ``i``, the running state covers
    ``num_computed_tokens + num_scheduled_tokens - num_draft_tokens`` tokens.
    After this step, ``num_accepted_tokens_cpu[i] - 1`` more are counted as
    computed. When the block-aligned new computed count reaches at least the
    running-state count, a copy is queued:

    * from block ``mamba_state_idx[req_id]``;
    * into the last block fully covered by that aligned count;
    * offset by the difference between the two counts.

    All copies are issued in one batch at the end. ``num_accepted_tokens_cpu[i]``
    is reset to 1 when the copy stays inside the running-state block.
    """
    num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
    spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized as this function always returns the same result
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    block_size = mamba_spec.block_size

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_draft_tokens = len(spec_decode_tokens_dict.get(req_id, []))
        running_state_tokens = (
            req_state.num_computed_tokens
            + num_scheduled_tokens_dict[req_id]
            - num_draft_tokens
        )
        accepted = num_accepted_tokens_cpu[i]
        new_num_computed_tokens = running_state_tokens + accepted - 1
        num_full_blocks = new_num_computed_tokens // block_size
        aligned_new_computed_tokens = num_full_blocks * block_size
        # TODO: how to ensure all blocks that cache_blocks called are cached here?
        if aligned_new_computed_tokens < running_state_tokens:
            # No block became full beyond the running state this step.
            continue
        src_block_idx = mamba_state_idx[req_id]
        dest_block_idx = num_full_blocks - 1
        collect_mamba_copy_meta(
            src_state_list,
            dest_state_list,
            num_elements_list,
            kv_cache_config,
            mamba_state_copy_funcs,
            mamba_group_ids,
            src_block_idx,
            dest_block_idx,
            aligned_new_computed_tokens - running_state_tokens,
            req_state,
            forward_context,
        )
        if src_block_idx == dest_block_idx:
            num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)