Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
|
||||
|
||||
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
|
||||
return mamba_group_ids, mamba_specs[0]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MambaCopyBuffers:
|
||||
src_ptrs: CpuGpuBuffer
|
||||
dst_ptrs: CpuGpuBuffer
|
||||
sizes: CpuGpuBuffer
|
||||
offset: int = 0
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
max_num_reqs: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
make_buffer: Callable[..., CpuGpuBuffer],
|
||||
) -> "MambaCopyBuffers":
|
||||
mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
|
||||
entries_per_req = sum(
|
||||
len(kv_cache_config.kv_cache_groups[gid].layer_names)
|
||||
for gid in mamba_group_ids
|
||||
) * len(copy_funcs)
|
||||
n = max_num_reqs * entries_per_req
|
||||
return cls(
|
||||
src_ptrs=make_buffer(n, dtype=torch.int64),
|
||||
dst_ptrs=make_buffer(n, dtype=torch.int64),
|
||||
sizes=make_buffer(n, dtype=torch.int32),
|
||||
)
|
||||
|
||||
|
||||
def collect_mamba_copy_meta(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
copy_bufs: MambaCopyBuffers,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
mamba_group_ids: list[int],
|
||||
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
|
||||
accept_token_bias: int,
|
||||
req_state: CachedRequestState,
|
||||
forward_context: dict[str, Any],
|
||||
):
|
||||
) -> None:
|
||||
if src_block_idx == dest_block_idx and accept_token_bias == 0:
|
||||
return
|
||||
|
||||
src_ptrs_np = copy_bufs.src_ptrs.np
|
||||
dst_ptrs_np = copy_bufs.dst_ptrs.np
|
||||
sizes_np = copy_bufs.sizes.np
|
||||
offset = copy_bufs.offset
|
||||
|
||||
for mamba_group_id in mamba_group_ids:
|
||||
block_ids = req_state.block_ids[mamba_group_id]
|
||||
dest_block_id = block_ids[dest_block_idx]
|
||||
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
|
||||
state, block_ids, src_block_idx, accept_token_bias + 1
|
||||
)
|
||||
|
||||
src_state_list.append(copy_spec.start_addr)
|
||||
dest_state_list.append(state[dest_block_id].data_ptr())
|
||||
num_elements_list.append(copy_spec.num_elements * state.element_size())
|
||||
src_ptrs_np[offset] = copy_spec.start_addr
|
||||
dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
|
||||
sizes_np[offset] = copy_spec.num_elements * state.element_size()
|
||||
offset += 1
|
||||
|
||||
copy_bufs.offset = offset
|
||||
|
||||
|
||||
def do_mamba_copy_block(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
):
|
||||
if len(src_state_list) == 0:
|
||||
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
|
||||
n = copy_bufs.offset
|
||||
if n == 0:
|
||||
return
|
||||
assert len(src_state_list) == len(dest_state_list)
|
||||
assert len(src_state_list) == len(num_elements_list)
|
||||
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
|
||||
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
|
||||
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
|
||||
|
||||
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
|
||||
batch_memcpy(
|
||||
copy_bufs.src_ptrs.copy_to_gpu(n),
|
||||
copy_bufs.dst_ptrs.copy_to_gpu(n),
|
||||
copy_bufs.sizes.copy_to_gpu(n),
|
||||
)
|
||||
|
||||
|
||||
def preprocess_mamba(
|
||||
@@ -117,6 +149,7 @@ def preprocess_mamba(
|
||||
requests: dict[str, CachedRequestState],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: MambaCopyBuffers,
|
||||
):
|
||||
"""
|
||||
Copy the mamba state of previous step to the last
|
||||
@@ -138,9 +171,7 @@ def preprocess_mamba(
|
||||
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
|
||||
mamba_state_idx.pop(req_id, None)
|
||||
|
||||
src_state_list: list[int] = []
|
||||
dest_state_list: list[int] = []
|
||||
num_elements_list: list[int] = []
|
||||
copy_bufs.offset = 0
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
prev_state_idx = mamba_state_idx.get(req_id)
|
||||
@@ -169,9 +200,7 @@ def preprocess_mamba(
|
||||
mamba_state_idx[req_id] = curr_state_idx
|
||||
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
|
||||
collect_mamba_copy_meta(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
copy_bufs,
|
||||
kv_cache_config,
|
||||
mamba_state_copy_funcs,
|
||||
mamba_group_ids,
|
||||
@@ -182,7 +211,7 @@ def preprocess_mamba(
|
||||
forward_context,
|
||||
)
|
||||
input_batch.num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
do_mamba_copy_block(copy_bufs)
|
||||
|
||||
|
||||
def postprocess_mamba(
|
||||
@@ -193,6 +222,7 @@ def postprocess_mamba(
|
||||
mamba_state_idx: dict[str, int],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: MambaCopyBuffers,
|
||||
):
|
||||
"""
|
||||
If a blocks is converted from partial block to full block in this step, copy the
|
||||
@@ -203,9 +233,7 @@ def postprocess_mamba(
|
||||
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
|
||||
# NOTE: can be optimized as this function always returns the same result
|
||||
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
src_state_list: list[int] = []
|
||||
dest_state_list: list[int] = []
|
||||
num_elements_list: list[int] = []
|
||||
copy_bufs.offset = 0
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
@@ -225,9 +253,7 @@ def postprocess_mamba(
|
||||
src_block_idx = mamba_state_idx[req_id]
|
||||
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
|
||||
collect_mamba_copy_meta(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
copy_bufs,
|
||||
kv_cache_config,
|
||||
mamba_state_copy_funcs,
|
||||
mamba_group_ids,
|
||||
@@ -239,4 +265,4 @@ def postprocess_mamba(
|
||||
)
|
||||
if src_block_idx == dest_block_idx:
|
||||
num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
do_mamba_copy_block(copy_bufs)
|
||||
|
||||
Reference in New Issue
Block a user