update
This commit is contained in:
242
vllm/v1/worker/mamba_utils.py
Normal file
242
vllm/v1/worker/mamba_utils.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateCopyFunc,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
|
||||
|
||||
|
||||
@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    """Copy ``sizes[pid]`` bytes from ``src_ptrs[pid]`` to ``dst_ptrs[pid]``.

    One program instance handles one (src, dst, size) triple from the
    batch; the byte range is streamed in BLOCK_SIZE-element tiles.
    """
    # Which copy of the batch this program instance performs.
    pid = tl.program_id(0)

    # src/dst are raw device addresses stored as integers in the pointer
    # tables (see do_mamba_copy_block, which builds them from data_ptr());
    # size is the number of bytes to move.
    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)

    offsets = tl.arange(0, BLOCK_SIZE)
    # Walk the range one tile at a time; the mask disables the
    # out-of-range lanes of the final, possibly partial, tile.
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size

        # Reinterpret the integer addresses as byte (uint8) pointers so the
        # copy is agnostic to the source tensor's element type.
        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))

        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)
|
||||
|
||||
|
||||
def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    """Launch one Triton program per (src, dst, size) triple.

    All three tensors must be 1-D with the same leading dimension; program
    ``i`` copies ``sizes[i]`` bytes from address ``src_ptrs[i]`` to address
    ``dst_ptrs[i]``.
    """
    num_copies = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == num_copies
    assert sizes.shape[0] == num_copies

    # One program per copy; each program streams its range in 1024-byte tiles.
    batch_memcpy_kernel[(num_copies,)](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=1024)
|
||||
|
||||
|
||||
def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    """Locate every KV-cache group that holds mamba state.

    Args:
        kv_cache_config: The worker's KV-cache configuration.

    Returns:
        A tuple of (indices of the mamba groups within
        ``kv_cache_config.kv_cache_groups``, the spec shared by those groups).

    Raises:
        AssertionError: If the model has no mamba layers, or the mamba
            groups disagree on their spec.
    """
    mamba_group_ids: list[int] = []
    mamba_specs: list[MambaSpec] = []
    # enumerate() instead of indexing with range(len(...)): same traversal,
    # idiomatic and avoids the double lookup of the group list.
    for group_id, group in enumerate(kv_cache_config.kv_cache_groups):
        if isinstance(group.kv_cache_spec, MambaSpec):
            mamba_group_ids.append(group_id)
            mamba_specs.append(group.kv_cache_spec)
    assert len(mamba_group_ids) > 0, "no mamba layers in the model"
    # All mamba groups must share one spec so a single copy routine
    # (block size, speculative blocks, ...) works for all of them.
    assert all(mamba_specs[0] == spec for spec in mamba_specs)
    return mamba_group_ids, mamba_specs[0]
|
||||
|
||||
|
||||
def collect_mamba_copy_meta(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
) -> None:
    """Append (src addr, dst addr, byte count) copy triples for one request.

    For every mamba layer of every mamba KV-cache group, asks the layer's
    state-copy function where the source state lives (given the source
    block index and the accept-token bias) and records a raw-pointer copy
    into the destination block. The three output lists are mutated in
    place; ``do_mamba_copy_block`` later performs the actual copies.
    """
    # Nothing to move when source and destination coincide and no bias
    # shifts the state position within the block.
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    for mamba_group_id in mamba_group_ids:
        # This request's physical block table for this cache group.
        block_ids = req_state.block_ids[mamba_group_id]
        dest_block_id = block_ids[dest_block_idx]
        layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
        for layer_name in layer_names:
            attention = forward_context[layer_name]
            # NOTE(review): assumes kv_cache[0] holds this layer's list of
            # mamba state tensors, one per copy func — confirm against the
            # layer's kv_cache layout.
            kv_caches: list[torch.Tensor] = attention.kv_cache[0]
            for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
                # The copy func resolves the source address and extent; the
                # "+ 1" turns the 0-based bias into a count — presumably
                # the number of accepted tokens; verify against
                # MambaStateCopyFunc's contract.
                copy_spec = state_copy_func(
                    state, block_ids, src_block_idx, accept_token_bias + 1
                )

                src_state_list.append(copy_spec.start_addr)
                dest_state_list.append(state[dest_block_id].data_ptr())
                # Copy sizes are in bytes: element count times element width.
                num_elements_list.append(copy_spec.num_elements * state.element_size())
|
||||
|
||||
|
||||
def do_mamba_copy_block(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
):
    """Upload the gathered (src, dst, nbytes) triples and run the batched copy.

    Does nothing when no copies were collected.
    """
    if not src_state_list:
        return
    n = len(src_state_list)
    assert len(dest_state_list) == n
    assert len(num_elements_list) == n

    # Addresses travel as int64, byte counts as int32, all on the device
    # the copy kernel runs on.
    src_addrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
    dst_addrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
    byte_counts = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)

    batch_memcpy(src_addrs, dst_addrs, byte_counts)
|
||||
|
||||
|
||||
def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
) -> None:
    """
    Copy the mamba state of previous step to the last
    (1 + num_speculative_blocks) block.

    Runs before the model forward: for each scheduled request, determines
    which block held the running mamba state last step (``mamba_state_idx``)
    and, if the running-state block has moved, gathers raw-pointer copies to
    relocate the state, then performs them in one batched kernel launch.
    Mutates ``mamba_state_idx`` and ``input_batch.num_accepted_tokens_cpu``.
    """
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size
    finished_req_ids = scheduler_output.finished_req_ids
    # preempted_req_ids may be None; normalize to an empty set.
    preempted_req_ids = scheduler_output.preempted_req_ids or set()
    # We need to clear mamba_state_idx for resumed requests. When requests are
    # force-preempted (e.g., during reset_prefix_cache / KV cache flush),
    # they appear in resumed_req_ids without a corresponding entry in
    # preempted_req_ids, leaving stale mamba_state_idx entries that can
    # point to block indices beyond the new (smaller) block allocation.
    resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
    for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
        mamba_state_idx.pop(req_id, None)

    # Copy metadata accumulated across all requests, flushed once at the end.
    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # new / resumed request, no previous state
            # if num_computed_tokens is 0, prev_state_idx will be -1
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size

        num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # Total blocks needed after this step, including the trailing
        # speculative blocks.
        num_blocks: int = (
            cdiv(req_state.num_computed_tokens + num_scheduled_tokens, block_size)
            + num_speculative_blocks
        )

        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
        # A corner case worth mention here: assume we have block_size = 4 and
        # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
        # tokens [draft 1, draft 2]. Then we will have:
        # Block 0: [A, B, C, draft 1]
        # Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        # Block 2: speculative block
        # Block 3: speculative block
        # And use block 1 to save the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx
        # prev_state_idx == -1 means no computed state yet (fresh request).
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                prev_state_idx,
                curr_state_idx,
                # Bias = accepted tokens beyond the first; presumably
                # consumed by the copy so the bias is not applied twice.
                input_batch.num_accepted_tokens_cpu[i] - 1,
                req_state,
                forward_context,
            )
            # Reset after the copy absorbed the bias — TODO confirm this
            # belongs inside the `if` (indentation ambiguous in review view).
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
|
||||
|
||||
def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
) -> None:
    """
    If a block is converted from a partial block to a full block in this
    step, copy the state from the running-state block to the new full block.

    Runs after the model forward (once the number of accepted tokens is
    known). May mutate ``input_batch.num_accepted_tokens_cpu``.
    """
    num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
    scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized as this function always returns the same result
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    # Copy metadata accumulated across all requests, flushed once at the end.
    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed_tokens = req_state.num_computed_tokens
        num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
        num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
        num_accepted_tokens = num_accepted_tokens_cpu[i]
        # Tokens covered by the running state: everything scheduled this
        # step except the (unverified) draft tokens.
        num_tokens_running_state = (
            num_computed_tokens + num_scheduled_tokens - num_draft_tokens
        )
        new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
        # Round down to a block boundary: only fully-filled blocks are
        # eligible to receive a cached state.
        aligned_new_computed_tokens = (
            new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
        )
        # TODO: how to ensure all blocks that cache_blocks called are cached here?
        if aligned_new_computed_tokens >= num_tokens_running_state:
            # Extra accepted tokens between the running state and the
            # block boundary the destination state must represent.
            accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
            src_block_idx = mamba_state_idx[req_id]
            # Index of the last block that just became full.
            dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                src_block_idx,
                dest_block_idx,
                accept_token_bias,
                req_state,
                forward_context,
            )
            # When the state was advanced in place (src == dest), the bias
            # has been absorbed into the stored state — reset the counter.
            if src_block_idx == dest_block_idx:
                num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
Reference in New Issue
Block a user